In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
import sys
import os

# Make the parent directory importable (presumably where the ``lit_modules``
# package imported below lives); the membership check avoids appending a
# duplicate entry every time this cell is re-run.
if ".." not in sys.path:
    sys.path.append("..")

In [3]:
from lit_modules.litdatamodules.lit_imagenet import LitImageNetDataModule
from lit_modules.litdatamodules.lit_lamem import LitLaMemDataModule

import lightning as L
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
from lightning.pytorch.utilities import CombinedLoader
from lightning.pytorch.callbacks import (
    DeviceStatsMonitor,
    StochasticWeightAveraging,
    LearningRateMonitor,
)

import torch
import torch.nn as nn
from torchvision.models import resnet50
from torchvision import transforms
from torch.utils.data import DataLoader
from torchmetrics import Accuracy

from typing import Dict, List
import pandas as pd
from tqdm import tqdm

from lit_modules import LitResNet50
from lightning import Trainer

import os
from torch.utils.tensorboard import SummaryWriter

2024-03-15 03:56:09.928808: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-15 03:56:12.334784: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-15 03:56:12.335426: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-15 03:56:12.363863: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-15 03:56:12.994669: I tensorflow/core/platform/cpu_feature_g

In [4]:
class LitCombineDataModule(L.LightningDataModule):
    """Serve LaMem (regression) and ImageNet (classification) side by side.

    Wraps a ``LitLaMemDataModule`` under the "regression" key and a
    ``LitImageNetDataModule`` under the "classification" key, and combines
    their loaders with Lightning's ``CombinedLoader`` so each step yields a
    dict holding one batch per task.

    Args:
        root_imagenet: ImageNet root, forwarded to ``LitImageNetDataModule``.
        meta_path_imagenet: forwarded to ``LitImageNetDataModule`` as ``meta_path``.
        root_lamem: LaMem root, forwarded to ``LitLaMemDataModule``.
        num_workers: worker count forwarded to both child modules.
        batch_size: per-task batch size forwarded to both child modules.
        mode: ``CombinedLoader`` mode used for the train and val loaders.
        desired_image_size: image size forwarded to both child modules.
    """

    def __init__(
        self,
        root_imagenet: str,
        meta_path_imagenet: str,
        root_lamem: str,
        num_workers: int = 10,
        batch_size: int = 32,
        mode: str = "min_size",
        desired_image_size: int = 224,
    ) -> None:
        super().__init__()
        # Only these two constructor arguments are recorded in ``self.hparams``.
        self.save_hyperparameters("batch_size", "desired_image_size")
        self.mode = mode

        shared_kwargs = dict(
            num_workers=num_workers,
            batch_size=batch_size,
            desired_image_size=desired_image_size,
        )
        # Insertion order (regression first) fixes the order in which the
        # children are set up and the key order of every combined batch.
        self.data_modules = {
            "regression": LitLaMemDataModule(root=root_lamem, **shared_kwargs),
            "classification": LitImageNetDataModule(
                root=root_imagenet, meta_path=meta_path_imagenet, **shared_kwargs
            ),
        }

    def setup(self, stage: str) -> None:
        """Forward ``setup(stage)`` to each wrapped data module in order."""
        for module in self.data_modules.values():
            module.setup(stage)

    def _combine_loaders(self, make_loader):
        """Build one loader per task with ``make_loader`` and combine them."""
        loaders = {
            task: make_loader(module) for task, module in self.data_modules.items()
        }
        return CombinedLoader(loaders, mode=self.mode)

    def train_dataloader(self):
        return self._combine_loaders(lambda module: module.train_dataloader())

    def val_dataloader(self):
        return self._combine_loaders(lambda module: module.val_dataloader())


class ResNet50(nn.Module):
    """ResNet-50 backbone shared by a classification and a regression task.

    The backbone is built with ``weights=None``, so it starts from random
    initialisation (it is NOT ImageNet-pretrained). For the "regression" task
    the 1000-d backbone output goes through ``Linear(1000, 1)`` + ``Sigmoid``,
    giving a value in (0, 1); for any other task the raw 1000-d backbone
    output is returned unchanged.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised backbone: ``weights=None`` loads no checkpoint.
        self.resnet = resnet50(weights=None)
        # Regression head. Attribute names are kept so state_dict keys stay stable.
        self.linear = nn.Linear(1000, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, task):
        features = self.resnet(x)
        if task != "regression":
            return features
        return self.sigmoid(self.linear(features))

In [5]:
def get_dataloaders(
    root_imagenet: str = "/datashare/ImageNet/ILSVRC2012",
    meta_path_imagenet: str = "/home/soroush1/projects/def-kohitij/soroush1/pretrain-imagenet/data/ImageNet",
    root_lamem: str = "/home/soroush1/projects/def-kohitij/soroush1/pretrain-imagenet/data/lamem/lamem_images/lamem",
    batch_size: int = 64,
    num_workers: int = 2,
    mode: str = "min_size",
    stage: str = "fit",
):
    """Build the combined LaMem + ImageNet train and validation loaders.

    All settings are parameters (defaulting to the values this notebook used)
    so paths and loader settings can be changed without editing the function.

    Args:
        root_imagenet: ImageNet root directory.
        meta_path_imagenet: ImageNet metadata directory.
        root_lamem: LaMem image directory.
        batch_size: per-task batch size.
        num_workers: DataLoader workers per child data module.
        mode: ``CombinedLoader`` mode.
        stage: stage passed to ``setup``. Lightning's documented stages are
            "fit", "validate", "test" and "predict"; "fit" is the one that
            needs both train and val data.

    Returns:
        ``(train_loader, val_loader)`` from the combined data module.
    """
    datamodule = LitCombineDataModule(
        root_imagenet=root_imagenet,
        meta_path_imagenet=meta_path_imagenet,
        root_lamem=root_lamem,
        batch_size=batch_size,
        mode=mode,
        num_workers=num_workers,
    )
    datamodule.setup(stage)

    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()
    return train_loader, val_loader


def calculate_initial_task_losses(model, train_loader, criterion, device):
    """Compute each task's loss on the first training batch, without gradients.

    These losses are the reference values L_i(0) that ``grad_norm_step``
    divides the current losses by.

    Args:
        model: callable ``model(x, task)`` returning predictions for ``task``.
        train_loader: iterable of ``(batch, batch_idx, dataloader_idx)`` whose
            ``batch`` maps a task name to an ``(x, y)`` pair.
        criterion: ``criterion[task]["loss"]`` is the loss function for ``task``.
        device: device the inputs and targets are moved to.

    Returns:
        List of loss tensors, one per task, in ``batch`` iteration order.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    from itertools import islice

    first_batch = list(islice(train_loader, 1))
    if not first_batch:
        raise ValueError(
            "train_loader yielded no batches; cannot compute initial task losses"
        )
    batch, _batch_idx, _dataloader_idx = first_batch[0]

    init_task_losses = []
    with torch.no_grad():
        for key, (x, y) in batch.items():
            x = x.to(device)
            y = y.to(device)
            # squeeze(-1) only drops the trailing size-1 dim of the regression
            # head; a bare squeeze() would also drop the batch dim when B == 1.
            outputs = model(x, key).squeeze(-1)
            init_task_losses.append(criterion[key]["loss"](outputs, y))

    print(f"{init_task_losses = }")

    return init_task_losses


def grad_norm_step(model, loss_weights, task_losses, initial_task_losses, alpha=0.16):
    """
    Re-weight task losses by their relative inverse training rate.

    Each weight is scaled by ``r_i / mean(r)`` with
    ``r_i = (L_i(t) / L_i(0)) ** alpha``, so tasks whose loss has fallen the
    least (relative to their initial loss) get a larger weight. This is the
    loss-ratio term of GradNorm (Chen et al., 2018); the gradient-norm
    balancing part of GradNorm is NOT performed here.

    The returned weights are detached, so they act as constants in the
    caller's backward pass instead of leaking extra gradient through
    ``task_losses``. No backward pass is run and model gradients are not
    touched.

    - model: unused; kept so existing call sites keep working
    - loss_weights: current weight per task (tensor or float)
    - task_losses: current loss per task (tensor or float)
    - initial_task_losses: reference loss per task, same order as task_losses
    - alpha: exponent on the loss ratios; larger values spread the weights further

    Returns a list of new weights (one per task), or ``[]`` when there are no tasks.
    """
    num_tasks = len(task_losses)
    if num_tasks == 0:
        return []

    with torch.no_grad():
        # (L_i(t) / L_i(0)) ** alpha for every task.
        inverse_training_rates = [
            (torch.as_tensor(task_losses[i]) / torch.as_tensor(initial_task_losses[i]))
            ** alpha
            for i in range(num_tasks)
        ]
        mean_rate = sum(inverse_training_rates) / num_tasks
        new_loss_weights = [
            torch.as_tensor(loss_weights[i]) * (rate / mean_rate)
            for i, rate in enumerate(inverse_training_rates)
        ]

    return [weight.detach() for weight in new_loss_weights]


def train_loop(
    train_loader,
    model,
    loss_weights,
    initial_task_losses,
    criterion,
    optimizer,
    scheduler,
    device,
    model_specification,
    epoch_num: int,
    step: bool = True,
    epoch: bool = True,
    overfit: bool = False,
    summary_writer=None,
):
    """Train ``model`` for one epoch on the multi-task loader.

    For every batch the per-task losses are computed and re-weighted with
    ``grad_norm_step``. The weighted sum is then back-propagated. Per-step
    metrics are appended to ``model_specification`` and written to TensorBoard.

    Args:
        train_loader: iterable of ``(batch, batch_idx, dataloader_idx)`` where
            ``batch`` maps task name -> ``(x, y)``; must support ``len()``.
        model: callable ``model(x, task)``.
        loss_weights: per-task weights passed to ``grad_norm_step``.
        initial_task_losses: reference losses passed to ``grad_norm_step``.
        criterion: ``criterion[task]["loss"]``, plus ``"top1"``/``"top5"``
            metrics for the "classification" task.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        device: device inputs/targets are moved to.
        model_specification: dict of ``tr_*`` metric lists, appended to in
            place; it may already hold values from earlier epochs.
        epoch_num: zero-based epoch index, used for TensorBoard step numbers.
        step: log per-step metrics.
        epoch: log per-epoch means (over this epoch's values only).
        overfit: train on only the first batch of ``train_loader``.
        summary_writer: TensorBoard writer; defaults to the notebook-global
            ``writer``.
    """
    from itertools import islice

    if summary_writer is None:
        summary_writer = writer  # notebook-global SummaryWriter

    model.train()

    # The metric lists persist across epochs; remember where this epoch starts
    # so step logging and epoch means only use this epoch's values.
    epoch_start = {key: len(values) for key, values in model_specification.items()}

    if overfit:
        # Only pull a batch out of the loader when it is actually going to be used.
        train_loader = list(islice(train_loader, 1))

    # Create the iterator before calling len(): Lightning's CombinedLoader only
    # reports its length once iter() has been called on it.
    batches = enumerate(train_loader)
    num_steps = len(train_loader)

    for i, (batch, _batch_idx, _dataloader_idx) in tqdm(batches, total=num_steps):
        global_step = epoch_num * num_steps + i
        step_start = {key: len(values) for key, values in model_specification.items()}

        task_losses = []
        optimizer.zero_grad()
        for key, (x, y) in batch.items():
            x = x.to(device)
            y = y.to(device)
            # squeeze(-1) only drops the trailing size-1 dim of the regression
            # head; a bare squeeze() would also drop the batch dim when B == 1.
            outputs = model(x, key).squeeze(-1)
            loss = criterion[key]["loss"](outputs, y)
            model_specification[f"tr_{key}_loss"].append(loss.item())

            if key == "classification":
                err_top1 = 1 - criterion[key]["top1"](outputs, y)
                err_top5 = 1 - criterion[key]["top5"](outputs, y)
                model_specification["tr_err_top1"].append(err_top1.item())
                model_specification["tr_err_top5"].append(err_top5.item())

            task_losses.append(loss)

        new_loss_weights = grad_norm_step(
            model, loss_weights, task_losses, initial_task_losses
        )
        # The weights are constants for the parameter update; detaching stops
        # any autograd graph they carry from adding extra gradient terms.
        new_loss_weights = [
            w.detach() if torch.is_tensor(w) else w for w in new_loss_weights
        ]
        weighted_loss = sum(
            weight * loss for weight, loss in zip(new_loss_weights, task_losses)
        )

        model_specification["tr_total_loss"].append(sum(task_losses).item())
        model_specification["tr_weight_loss"].append(weighted_loss.item())

        weighted_loss.backward()
        optimizer.step()

        summary_writer.add_scalar(
            "Learning Rate", optimizer.param_groups[0]["lr"], global_step
        )
        if step:
            for key, values in model_specification.items():
                # Only log metrics that were produced in this step.
                if len(values) > step_start[key]:
                    summary_writer.add_scalar(
                        f"Training_step/{key}", values[-1], global_step
                    )

    scheduler.step()

    if epoch:
        for key, values in model_specification.items():
            epoch_values = values[epoch_start[key]:]
            if epoch_values:
                summary_writer.add_scalar(
                    f"Training_epoch/{key}",
                    sum(epoch_values) / len(epoch_values),
                    epoch_num,
                )


def val_loop(
    val_loader,
    model,
    criterion,
    device,
    model_specification,
    epoch_num: int,
    step: bool = True,
    epoch: bool = True,
    overfit: bool = False,
    summary_writer=None,
):
    """Evaluate ``model`` for one epoch without gradients and log the metrics.

    Args:
        val_loader: iterable of ``(batch, batch_idx, dataloader_idx)`` where
            ``batch`` maps task name -> ``(x, y)``; must support ``len()``.
        model: callable ``model(x, task)``; switched to eval mode (and left there).
        criterion: ``criterion[task]["loss"]``, plus ``"top1"``/``"top5"``
            metrics for the "classification" task.
        device: device inputs/targets are moved to.
        model_specification: dict of ``val_*`` metric lists, appended to in
            place; it may already hold values from earlier epochs.
        epoch_num: zero-based epoch index, used for TensorBoard step numbers.
        step: log per-step metrics.
        epoch: log per-epoch means (over this epoch's values only).
        overfit: evaluate on only the first batch of ``val_loader``.
        summary_writer: TensorBoard writer; defaults to the notebook-global
            ``writer``.
    """
    from itertools import islice

    if summary_writer is None:
        summary_writer = writer  # notebook-global SummaryWriter

    model.eval()

    # The metric lists persist across epochs; remember where this epoch starts.
    epoch_start = {key: len(values) for key, values in model_specification.items()}

    if overfit:
        # Only pull a batch out of the loader when it is actually going to be used.
        val_loader = list(islice(val_loader, 1))

    # Create the iterator before calling len(): Lightning's CombinedLoader only
    # reports its length once iter() has been called on it.
    batches = enumerate(val_loader)
    num_steps = len(val_loader)

    with torch.no_grad():
        for i, (batch, _batch_idx, _dataloader_idx) in tqdm(batches, total=num_steps):
            step_start = {
                key: len(values) for key, values in model_specification.items()
            }
            for key, (x, y) in batch.items():
                x = x.to(device)
                y = y.to(device)
                # squeeze(-1): keep the batch dim even when B == 1.
                outputs = model(x, key).squeeze(-1)
                loss = criterion[key]["loss"](outputs, y)
                model_specification[f"val_{key}_loss"].append(loss.item())

                if key == "classification":
                    err_top1 = 1 - criterion[key]["top1"](outputs, y)
                    err_top5 = 1 - criterion[key]["top5"](outputs, y)
                    model_specification["val_err_top1"].append(err_top1.item())
                    model_specification["val_err_top5"].append(err_top5.item())

            if step:
                global_step = epoch_num * num_steps + i
                for key, values in model_specification.items():
                    # Only log metrics that were produced in this step.
                    if len(values) > step_start[key]:
                        summary_writer.add_scalar(
                            f"Validation_step/{key}", values[-1], global_step
                        )

    if epoch:
        for key, values in model_specification.items():
            epoch_values = values[epoch_start[key]:]
            if epoch_values:
                summary_writer.add_scalar(
                    f"Validation_epoch/{key}",
                    sum(epoch_values) / len(epoch_values),
                    epoch_num,
                )


def get_next_version_number(base_dir, experiment_name):
    """Return the next free version index under ``base_dir/experiment_name``.

    Creates the experiment directory if it does not exist yet. A subdirectory
    counts as a version when its name ends in ``_<integer>`` (e.g.
    ``version_3``). Other directories, such as ``.ipynb_checkpoints``, are
    ignored instead of crashing the ``int()`` parse. Plain files are ignored.

    Returns:
        ``0`` when no numbered version directory exists, otherwise one more
        than the highest number found.
    """
    experiment_path = os.path.join(base_dir, experiment_name)
    # exist_ok avoids a race if another process creates the directory first.
    os.makedirs(experiment_path, exist_ok=True)

    versions = []
    for name in os.listdir(experiment_path):
        if not os.path.isdir(os.path.join(experiment_path, name)):
            continue
        try:
            versions.append(int(name.rsplit("_", 1)[-1]))
        except ValueError:
            continue  # not a numbered version directory

    return max(versions) + 1 if versions else 0

In [None]:
epoch = 8  # number of training epochs
SEED = 0
# ``is_available`` must be *called*: the bare function object is always truthy,
# which forced "cuda:0" even on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)  # reproducible weight init / shuffling

train_loader, val_loader = get_dataloaders()
model = ResNet50()
model.to(device)

classes = 1000
criterion = {
    "regression": {"loss": nn.MSELoss()},
    "classification": {
        "loss": nn.CrossEntropyLoss(),
        "top1": Accuracy(task="multiclass", num_classes=classes, top_k=1).to(device),
        "top5": Accuracy(task="multiclass", num_classes=classes, top_k=5).to(device),
    },
}

# One starting weight per task, created directly on ``device``. The previous
# ``Parameter(...).cuda()`` hard-coded CUDA and produced a non-leaf copy that
# could never be optimised. These weights are not in the optimizer;
# ``grad_norm_step`` derives the per-step weights from them.
loss_weights = [torch.ones(1, device=device) for _ in criterion]

optimizer = torch.optim.Adam(
    params=model.parameters(), betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-3
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

initial_task_losses = calculate_initial_task_losses(
    model, train_loader, criterion, device
)

tr_model_specification = {
    "tr_regression_loss": [],
    "tr_classification_loss": [],
    "tr_total_loss": [],
    "tr_weight_loss": [],
    "tr_err_top1": [],
    "tr_err_top5": [],
}

val_model_specification = {
    "val_regression_loss": [],
    "val_classification_loss": [],
    "val_err_top1": [],
    "val_err_top5": [],
}

# Base directory where all TensorBoard logs are stored
base_dir = "runs"

# Name of the experiment
experiment_name = "your_experiment_name"

# Get the next version number for the experiment
next_version_number = get_next_version_number(base_dir, experiment_name)

log_dir = os.path.join(base_dir, experiment_name, f"version_{next_version_number}")

# Initialize the TensorBoard writer with the unique log directory
writer = SummaryWriter(log_dir)

overfit = False

try:
    for i in tqdm(range(epoch)):
        train_loop(
            train_loader,
            model,
            loss_weights,
            initial_task_losses,
            criterion,
            optimizer,
            scheduler,
            device,
            tr_model_specification,
            epoch_num=i,
            overfit=overfit,
        )

        val_loop(
            val_loader,
            model,
            criterion,
            device,
            val_model_specification,
            epoch_num=i,
            overfit=overfit,
        )
finally:
    # Flush pending TensorBoard events to disk even if training is interrupted.
    writer.close()

In [None]:
sample_data = torch.rand((10, 3, 224, 224))

# Use a separate throw-away model: re-binding the global ``model`` here would
# silently replace the trained model from the training cell.
sample_model = ResNet50().eval()
with torch.no_grad():
    regression_out = sample_model(sample_data, "regression")
    classification_out = sample_model(sample_data, "classification")

print(sample_data.size())
print(regression_out.size())
print(classification_out.size())

assert regression_out.shape == (10, 1)
assert classification_out.shape == (10, 1000)

In [None]:
# Peek at the shapes of the first few combined batches only. Iterating the
# whole loader would read an entire epoch of ImageNet + LaMem just to print.
num_batches_to_inspect = 2
for i, (batch, batch_id, dataloader_idx) in enumerate(train_loader):
    print(f"{batch_id = }, {dataloader_idx = }")

    for key, value in batch.items():
        x, y = value
        print(f"{key}: {x.size()}, {y.size()}")

    if i + 1 >= num_batches_to_inspect:
        break
        