In [None]:
import sys

from google.colab import drive
# Mount Google Drive and make the project's helper modules (imported in the
# next cell: train_NN, colab_utils, colab_functions) importable.
drive.mount('/content/drive')
_PROJECT_DIR = '/content/drive/MyDrive/DeepLCMS/train_google_colab'
# Guard so re-running the cell does not append duplicate sys.path entries.
if _PROJECT_DIR not in sys.path:
    sys.path.append(_PROJECT_DIR)

In [None]:
import train_NN, colab_utils, colab_functions

In [None]:
!unzip -q experiment.zip


# Import and install libraries

In [None]:
import torch
import torchvision

from torch import nn
from torch.autograd import Variable
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from tqdm.auto import tqdm

In [None]:
!pip install lightning
!pip install timm
import timm

In [None]:
if int(torchvision.__version__.split(sep=".")[1]) < 13:
    !conda uninstall pytorch
    !pip uninstall torch --yes
    !pip uninstall torch --yes# run this command twice

    !conda uninstall torchvision
    !pip uninstall torchvision --yes
    !pip uninstall torchvision --yes # run this command twice

    !conda install --yes pytorch torchvision
    import torch
    import torchvision

    print(f"Current version of torch: {torch.__version__}")
    print(f"Current version of torchvision: {torchvision.__version__}")

else:
    import torch
    import torchvision

    print(f"Current version of torch: {torch.__version__}")
    print(f"Current version of torchvision: {torchvision.__version__}")

In [None]:
if importlib.util.find_spec("torchinfo") is None:
    print("torchinfo" + " is not installed")
    !pip install torchinfo
    from torchinfo import summary
    from tqdm.auto import tqdm
else:
    from torchinfo import summary
    from tqdm.auto import tqdm

# Check if GPU is used

In [None]:
# Report which accelerator this runtime provides (and the GPU model, if any).
if torch.cuda.is_available():
    device = "cuda"
    print(torch.cuda.get_device_name(0))
else:
    device = "cpu"
device

# Unzip folder

In [None]:
# !unzip -q experiment.zip

List the ResNet variants in timm that have pretrained weights:

In [None]:
# Names of timm models matching the "resnet*" pattern that ship pretrained weights.
timm.list_models(filter="resnet*", pretrained=True)

# Build model (timm `resnet14t.c3_in1k` backbone with a binary classification head)

In [None]:
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning.callbacks import Callback
from timm import create_model
from torchmetrics import Accuracy
from torchmetrics.classification import BinaryAUROC, BinaryF1Score, BinaryConfusionMatrix


class MetricsCallback(Callback):
    """Collect a snapshot of the trainer's logged metrics after each validation run."""

    def __init__(self):
        super().__init__()
        # One dict per validation run, in the order the runs happened.
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        # Store a shallow copy rather than `trainer.logged_metrics` itself:
        # appending the live mapping would let later updates by the trainer
        # (if it reuses the same dict) rewrite earlier entries.
        self.metrics.append(dict(trainer.logged_metrics))


class LitModel(pl.LightningModule):
    """Binary image classifier on a frozen, pretrained timm ResNet backbone.

    Every backbone parameter is frozen; only the replacement ``fc`` head
    (an MLP ending in a single logit) is trained.

    Args:
        lr: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
        t_max: ``T_max`` of the cosine-annealing schedule. Lightning steps the
            scheduler once per epoch by default, so keep this equal to the
            Trainer's ``max_epochs``.
    """

    def __init__(self, lr=0.001, weight_decay=2e-5, t_max=20):
        super().__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        self.t_max = t_max

        self.model = create_model(
            "resnet14t.c3_in1k", pretrained=True, num_classes=1
        )

        # Freeze the whole backbone. The head assigned below is created
        # afterwards, so its parameters remain trainable.
        for param in self.model.parameters():
            param.requires_grad = False

        # Size the head from the backbone's pooled feature width instead of a
        # hard-coded 2048, so it keeps working if the ResNet variant changes.
        self.model.fc = nn.Sequential(
            nn.Linear(in_features=self.model.num_features, out_features=512, bias=True),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=512, out_features=256, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1, bias=True),
        )

    def forward(self, x):
        """Return raw logits of shape (batch, 1)."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute and log loss, accuracy and F1 for one labelled batch.

        Args:
            batch: ``(images, labels)`` pair; labels are 0/1 class indices.
            stage: Metric-name prefix, ``"train"`` or ``"val"``.

        Returns:
            The scalar loss tensor.
        """
        x, y = batch

        # view(-1) rather than squeeze(): for a batch of one image squeeze()
        # yields a 0-d tensor, which fails the loss shape check and len().
        logits = self(x).view(-1)

        # Fused sigmoid + BCE is numerically stable for large |logits|,
        # unlike applying sigmoid first and then BCELoss.
        loss = F.binary_cross_entropy_with_logits(logits, y.float())
        self.log(f"{stage}_loss", loss)

        # Hard 0/1 predictions; torch.round maps an exact 0.5 to 0.
        y_pred_class = torch.round(torch.sigmoid(logits))

        acc = (y_pred_class == y).float().mean()
        self.log(f"{stage}_acc", acc)

        f1 = BinaryF1Score().to(y.device)(y_pred_class, y)
        self.log(f"{stage}_f1", f1)

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return logits for a prediction batch.

        ``batch`` may be an ``(inputs, targets)`` sequence, as yielded by an
        ImageFolder DataLoader, or a bare input tensor.
        """
        inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(inputs)

    def configure_optimizers(self):
        # Only the unfrozen head parameters need optimizing.
        trainable = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(
            trainable,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.t_max, eta_min=0
        )
        return [optimizer], [scheduler]

In [None]:
model = LitModel()

# Summarise at the input resolution declared by the backbone's pretrained
# config (the same config the transforms are built from), instead of a fixed
# 384x384, so the reported layer output sizes match what training actually sees.
summary(
    model=model,
    input_size=(32, *model.model.default_cfg["input_size"]),  # 32 = training batch size
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
# Resolve data configuration for the model
# Preprocessing settings taken from the backbone's pretrained configuration.
data_cfg = timm.data.resolve_data_config(model.model.default_cfg)

# Evaluation-mode transform used for validation.
preprocess_val = timm.data.create_transform(is_training=False, **data_cfg)

# The training transform is currently the same evaluation-mode pipeline
# (is_training=False, i.e. no random augmentation). Options tried before:
# no_aug=True, re_prob=0.1, re_mode="pixel", auto_augment="rand-m1-mstd0.1-inc0".
preprocess_train = timm.data.create_transform(is_training=False, **data_cfg)

In [None]:
import os  # was used here without being imported

BATCH_SIZE = 32
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 0 (load in the main process) instead of passing None to DataLoader.
NUM_WORKERS = os.cpu_count() or 0

# NOTE(review): `train_dir` and `test_dir` are not defined in the visible
# cells — presumably they point into the unzipped experiment folder; confirm.
train_data = datasets.ImageFolder(
    root=train_dir,
    transform=preprocess_train,
    target_transform=None,
)

# FIXME(review): the validation set is built from `train_dir`, so the val_*
# metrics are computed on the training images and say nothing about
# generalisation. Point `root` at a held-out directory.
val_data = datasets.ImageFolder(
    root=train_dir,
    transform=preprocess_val,
    target_transform=None,
)

train_dataloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

val_dataloader = DataLoader(
    val_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [None]:
import matplotlib.pyplot as plt  # `plt` was used here without being imported

# Visual sanity check of one preprocessed training batch.
images, labels = next(iter(train_dataloader))

# The timm transforms presumably end with mean/std normalisation, so pixel
# values fall outside [0, 1]; min-max rescale the grid, otherwise imshow
# clips the out-of-range values and the preview is distorted.
grid = make_grid(images, normalize=True)

fig, ax = plt.subplots(figsize=(15, 25))
ax.imshow(grid.permute(1, 2, 0))
ax.axis("off")
fig.tight_layout()

fig.savefig("transformed_grid.png", dpi=300)

In [None]:
# Launch TensorBoard inside the notebook. The log directory is assumed to be
# where the Trainer below writes its default `lightning_logs` when the working
# directory is /content (Colab) — confirm if the logger is configured differently.
%reload_ext tensorboard
%tensorboard --logdir='/content/lightning_logs'

# Train model

In [None]:
# Set the CUDA_VISIBLE_DEVICES environment variable
import os  # was used here without being imported

# Set the CUDA_VISIBLE_DEVICES environment variable.
# NOTE(review): when a GPU is present, CUDA has presumably already been
# initialised by earlier cells (e.g. torch.cuda.get_device_name), in which case
# this no longer changes device visibility for this process; set it before the
# first CUDA call if device selection matters.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

metrics_callback = MetricsCallback()

# max_epochs should match the cosine scheduler's T_max (20) in LitModel.
trainer = pl.Trainer(max_epochs=20, callbacks=[metrics_callback], log_every_n_steps=1)
trainer.fit(
    model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader
)

In [None]:
# https://github.com/frgfm/torch-cam

# Evaluate the test set


In [None]:
# Same evaluation-mode preprocessing as the validation set.
preprocess_test = timm.data.create_transform(is_training=False, **data_cfg)

test_data = datasets.ImageFolder(
    test_dir,
    transform=preprocess_test,
    target_transform=None,
)

# No shuffling, so predictions line up with test_data.targets; the last
# partial batch is kept.
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Run inference; the result holds one logits tensor per batch.
model.eval()
predictions = trainer.predict(model=model, dataloaders=test_dataloader)

In [None]:
# Ground-truth class indices of the test images, in dataset order.
all_labels = torch.as_tensor(test_dataloader.dataset.targets)
all_labels

In [None]:
# Decision threshold on the sigmoid probability.
threshold = 0.5

# Concatenate the per-batch logits and map them to probabilities.
probabilities = torch.cat(predictions, dim=0).sigmoid()

# Flat 0.0/1.0 class predictions.
binary_predictions = probabilities.gt(threshold).float().view(-1)
binary_predictions

In [None]:
# Test-set metrics. torchmetrics expects (preds, target) in that order; the
# previous (labels, preds) order transposed the confusion matrix.
assert len(binary_predictions) == len(all_labels), "prediction/label count mismatch"

acc = (all_labels == binary_predictions).sum().item() / len(all_labels)

metric_f1 = BinaryF1Score()
f1 = metric_f1(binary_predictions, all_labels)

bcm = BinaryConfusionMatrix()
bcm(binary_predictions, all_labels)
fig_, ax_ = bcm.plot()

print(f"Test accuracy: {acc:.3f} | Test F1: {f1.item():.3f}")