In [None]:
import sys

from google.colab import drive

# Mount Google Drive so the project's helper modules stored there are reachable.
drive.mount('/content/drive')

# Directory holding the project helpers (train_NN, prepare_data, colab_functions, ...).
PROJECT_DIR = '/content/drive/MyDrive/DeepLCMS/train_google_colab'
# Guard the append so re-running this cell does not stack duplicate sys.path entries.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [None]:
import train_NN, colab_utils, colab_functions, prepare_data
import os
import pytorch_lightning as pl

In [None]:
!unzip -q experiment.zip


# Import and install libraries

In [None]:
import torch
import torchvision
import importlib

from torch import nn
from torch.autograd import Variable
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from tqdm.auto import tqdm

In [None]:
%%capture
!pip install lightning
!pip install timm
import timm

In [None]:
if int(torchvision.__version__.split(sep=".")[1]) < 13:
    !conda uninstall pytorch
    !pip uninstall torch --yes
    !pip uninstall torch --yes# run this command twice

    !conda uninstall torchvision
    !pip uninstall torchvision --yes
    !pip uninstall torchvision --yes # run this command twice

    !conda install --yes pytorch torchvision
    import torch
    import torchvision

    print(f"Current version of torch: {torch.__version__}")
    print(f"Current version of torchvision: {torchvision.__version__}")

else:
    import torch
    import torchvision

    print(f"Current version of torch: {torch.__version__}")
    print(f"Current version of torchvision: {torchvision.__version__}")

In [None]:
if importlib.util.find_spec("torchinfo") is None:
    print("torchinfo" + " is not installed")
    !pip install torchinfo
    import torchinfo
    from tqdm.auto import tqdm
else:
    import torchinfo
    from tqdm.auto import tqdm

# Check whether a GPU is available

In [None]:
# Detect the available accelerator via the project helper and show it as the cell
# output, so this "check" cell actually displays the result.
# NOTE(review): `device` is not passed to the Lightning Trainer used for training,
# which selects its own accelerator.
device = colab_functions.get_device()
device


In [None]:
# List the ResNet-family architectures ("resnet*") for which timm provides
# pretrained weights; use it to pick a backbone name for the model below.
pretrained_resnets = timm.list_models(filter="resnet*", pretrained=True)
pretrained_resnets

# Training Recipe


In [None]:
# Instantiate the project's default LightningModule and display its architecture
# via the train_NN.show_architecture helper.
# NOTE(review): `LitModel` holds an *instance*, not the class, and it is not used
# by the training/evaluation cells below, which work with `example_model`.
LitModel = train_NN.LitModel()
train_NN.show_architecture(LitModel)

In [None]:
# timm model identifier of the pretrained backbone wrapped by
# PretrainedModelEvaluator; candidates are listed by `timm.list_models` above.
MODEL_NAME = "resnet10t.c3_in1k"
example_model = train_NN.PretrainedModelEvaluator(MODEL_NAME)

In [None]:
# Derive the train / validation / test preprocessing pipelines from the model
# using the prepare_data helper.
(
    preprocess_train,
    preprocess_val,
    preprocess_test,
) = prepare_data.get_timm_transforms(example_model)

In [None]:
# Build the three dataloaders, each paired with its matching preprocessing pipeline.
dataloader_kwargs = dict(
    preprocess_train=preprocess_train,
    preprocess_val=preprocess_val,
    preprocess_test=preprocess_test,
)
train_dataloader, val_dataloader, test_dataloader = prepare_data.get_dataloaders(
    **dataloader_kwargs
)

In [None]:
# Sanity-check the training dataloader before training; what exactly is shown is
# defined by prepare_data.inspect_dataloader.
prepare_data.inspect_dataloader(train_dataloader)

In [None]:
# Load (or reload, so re-running the cell is harmless) the TensorBoard notebook
# extension and open a TensorBoard panel on /content/lightning_logs.
# NOTE(review): presumably this is the default `lightning_logs/` directory that
# `pl.Trainer` writes to when the working directory is /content (Colab's default).
%reload_ext tensorboard
%tensorboard --logdir='/content/lightning_logs'

# Train model

In [None]:
# Train `example_model` on the training split, validating on the validation split.
MAX_EPOCHS = 2

metrics_callback = train_NN.MetricsCallback()

# `devices=1` pins training to a single device (the first GPU when one is present).
# This replaces setting CUDA_VISIBLE_DEVICES here: that variable is only honoured
# if CUDA has not been initialised yet in this process, and earlier cells (device
# detection, model creation) may already have initialised it.
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    devices=1,
    callbacks=[metrics_callback],
    log_every_n_steps=1,  # log metrics after every training step
)
trainer.fit(
    model=example_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
# Possible next step: class-activation-map (CAM) visualisation with torch-cam — https://github.com/frgfm/torch-cam

# Evaluate the test set


In [None]:
# Run inference on the held-out test split with the model trained above.
# Reuses `test_dataloader` from prepare_data.get_dataloaders (built with
# `preprocess_test`) instead of rebuilding it from names (data_cfg, test_dir,
# BATCH_SIZE, NUM_WORKERS, model) that are not defined anywhere in this notebook.
# NOTE(review): the label alignment in the following cells assumes this
# dataloader does not shuffle - TODO confirm in prepare_data.get_dataloaders.
# `trainer.predict` switches the model to eval mode itself, so no explicit
# `.eval()` call is needed.
predictions = trainer.predict(example_model, dataloaders=test_dataloader)

In [None]:
# Ground-truth class indices of the test set, in dataset order.
# NOTE(review): relies on `test_dataloader.dataset` exposing `.targets` (as
# torchvision's ImageFolder does) - TODO confirm for prepare_data's dataset.
test_targets = test_dataloader.dataset.targets
all_labels = torch.tensor(test_targets)
all_labels

In [None]:
# Merge the per-batch model outputs and turn them into one flat vector of 0./1.
# predictions. The outputs are treated as logits (hence the sigmoid).
# NOTE(review): assumes one output value per sample - TODO confirm predict_step.
threshold = 0.5
stacked_outputs = torch.cat(predictions, dim=0)
probabilities = stacked_outputs.sigmoid()
is_positive = probabilities.gt(threshold)
binary_predictions = is_positive.float().view(-1)
binary_predictions

In [None]:
# Test-set metrics: accuracy, F1 score and a confusion-matrix plot.
from torchmetrics.classification import BinaryConfusionMatrix, BinaryF1Score

# torchmetrics binary metrics take (preds, target) in that order and expect integer
# targets; swapping the arguments transposes the confusion matrix. `.cpu()` is
# defensive so both tensors are guaranteed to share a device.
preds_int = binary_predictions.long().cpu()
target_int = all_labels.long().cpu()
assert len(target_int) > 0, "empty test set"
assert preds_int.shape == target_int.shape, "predictions and labels are misaligned"

acc = (target_int == preds_int).sum().item() / len(target_int)

metric_f1 = BinaryF1Score()
f1 = metric_f1(preds_int, target_int)

bcm = BinaryConfusionMatrix()
bcm(preds_int, target_int)
print(f"Test accuracy: {acc:.4f}  |  F1 score: {f1.item():.4f}")
fig_, ax_ = bcm.plot()