In [1]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/workspace/illusion/

In [2]:
! pip install -r requirements.txt

In [3]:
import os
import torch
import wandb
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning import seed_everything
from sklearn.metrics import accuracy_score
from torch.utils.data import ConcatDataset
import torchvision
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from math import sqrt
from timm import create_model

# Seed for reproducibility
seed_everything(42)


class Model(LightningModule):
    def __init__(self, model_name, steps_per_epoch, num_classes, lr):
        super(Model, self).__init__()

        self.save_hyperparameters()

        # Load pre-trained ResNet50 model from timm with the correct number of classes
        self.model = create_model(model_name, pretrained=False, num_classes=num_classes)

        # Loss function and learning rate
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr
        self.steps_per_epoch = steps_per_epoch

    def forward(self, x):
        # Forward pass through the entire model
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = accuracy_score(labels.cpu(), outputs.argmax(dim=1).cpu())
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = accuracy_score(labels.cpu(), outputs.argmax(dim=1).cpu())
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.lr, steps_per_epoch=self.steps_per_epoch, epochs=self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


In [4]:
from data_indl_and_cifar10 import trainloader_combined, testloader_combined

class ModifiedModel(Model):
    def validation_step(self, batch, batch_idx):
        images, labels = batch

        # Forward pass for all images
        outputs = self(images)

        # Compute loss and accuracy for all labels (val_all)
        loss_all = self.criterion(outputs, labels)
        preds_all = torch.argmax(outputs, dim=1)
        acc_all = (preds_all == labels).float().mean()

        # Filter labels within the range of 0 to 9 (val_B)
        valid_indices = (labels >= 0) & (labels <= 9)

        if valid_indices.any():
            valid_images = images[valid_indices]
            valid_labels = labels[valid_indices]
            logits_valid = self(valid_images)  # Forward pass for filtered images

            loss_valid = F.cross_entropy(logits_valid, valid_labels)
            preds_valid = torch.argmax(logits_valid, dim=1)
            acc_valid = (preds_valid == valid_labels).float().mean()

            # Log validation metrics for filtered labels 0-9
            self.log('val_loss', loss_valid, prog_bar=True)
            self.log('val_acc', acc_valid, prog_bar=True)

        else:
            loss_valid = None
            acc_valid = None

        # Filter labels outside the range of 0 to 9 (val_exc for exclusive labels 10-11)
        exc_indices = (labels >= 10) & (labels <= 11)

        if exc_indices.any():
            exc_images = images[exc_indices]
            exc_labels = labels[exc_indices]
            logits_exc = self(exc_images)  # Forward pass for excluded images

            loss_exc = F.cross_entropy(logits_exc, exc_labels)
            preds_exc = torch.argmax(logits_exc, dim=1)
            acc_exc = (preds_exc == exc_labels).float().mean()

            # Log validation metrics for labels 10-11
            self.log('val_exc_loss', loss_exc, prog_bar=True)
            self.log('val_exc_acc', acc_exc, prog_bar=True)

        else:
            loss_exc = None
            acc_exc = None

        # Log overall validation loss and accuracy (for all labels)
        self.log('val_all_loss', loss_all, prog_bar=True)
        self.log('val_all_acc', acc_all, prog_bar=True)

        # Return all the metrics
        return {
            'val_loss': loss_valid,
            'val_acc': acc_valid,
            'val_exc_loss': loss_exc,
            'val_exc_acc': acc_exc,
            'val_all_loss': loss_all,
            'val_all_acc': acc_all,
        }


def main(model_name):

    # Initialize Wandb logger with a careful naming convention for the model
    wandb_logger = WandbLogger(project="illusion_augmented_models", name=f"model_m_X_{model_name}", log_model=True)
    # Experiment name and setup
    exp_name = "illusion_augmented_image_classification_model"

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc",
        dirpath=f"./models/m_X_{model_name}/",
        filename="{epoch:02d}-{val_acc:.2f}",
        save_top_k=1,
        mode="max",
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Early stopping callback to stop training when the monitored metric has stopped improving
    early_stop_callback = EarlyStopping(
        monitor='val_loss',  # or any other metric you are monitoring
        patience=3,  # Number of epochs to wait after the last improvement
        mode='min',  # 'min' for loss, 'max' for accuracy
        verbose=True
    )

    # Path for latest checkpoint
    checkpoint_dir = f"./models/m_X_{model_name}/"
    latest_checkpoint = None

    # Check if a checkpoint exists
    if os.path.exists(checkpoint_dir):
        checkpoints = os.listdir(checkpoint_dir)
        if checkpoints:
            latest_checkpoint = max(
                [os.path.join(checkpoint_dir, ckpt) for ckpt in checkpoints],
                key=os.path.getctime
            )

    # Training model instance
    model = ModifiedModel(model_name, steps_per_epoch=len(trainloader_combined), num_classes=12, lr=1e-4)

    # Trainer configuration
    trainer = Trainer(
        logger=wandb_logger,
        callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
        accelerator="auto",
        devices=1
    )

    # Model training
    trainer.fit(model, train_dataloaders=trainloader_combined, val_dataloaders=testloader_combined)

Files already downloaded and verified
Files already downloaded and verified
Loading train dataset from database/indl/train/combined_trainset.pth and test dataset from database/indl/test/combined_testset.pth...


  train_data = torch.load(train_pth_file)
  test_data = torch.load(test_pth_file)


In [None]:
for model_name in ["convnextv2_huge", "convnext_xxlarge", "resnet50"]:
    main(model_name)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /content/drive/MyDrive/workspace/illusion/models/m_X_convnextv2_huge exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | ConvNeXt         | 657 M  | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
657 M     Trainable params
0         Non-trainable params
657 M     Total params
2,630.026 Total estimated model params size (MB)
466       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved. New best score: 1.445


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.083 >= min_delta = 0.0. New best score: 1.362


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Monitored metric val_loss did not improve in the last 3 records. Best score: 1.362. Signaling Trainer to stop.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /content/drive/MyDrive/wo

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved. New best score: 1.782


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.127 >= min_delta = 0.0. New best score: 1.655


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.060 >= min_delta = 0.0. New best score: 1.595


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.044 >= min_delta = 0.0. New best score: 1.551


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.034 >= min_delta = 0.0. New best score: 1.517


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.026 >= min_delta = 0.0. New best score: 1.490


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.031 >= min_delta = 0.0. New best score: 1.459


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 1.441
