# FIT3164 - MDS08


## Setup


In [None]:
!python pip install -r requirements.txt

In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
# uncomment if using cpu training
# !pip3 install torch torchvision torchaudio

## Download Data


In [None]:
!pip3 install zenodo_get

In [None]:
!zenodo_get 4756317

## Fine-tuning


In [None]:
# all imports
import os
from GPUtil import showUtilization as gpu_usage
from numba import cuda


import cv2
import logging



import numpy as np
import torch
from torch import nn, optim



from torch.utils.data import Dataset, DataLoader


from torchvision import transforms


from sklearn.model_selection import train_test_split


from PIL import Image
import yaml
import csv
from torch.nn.utils.rnn import pad_sequence


import pytorch_lightning as pl


from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


from pytorch_lightning.loggers import TensorBoardLogger


from transformers import ViTImageProcessor, ViTModel, ViTConfig


import tensorboard


import tensorboardX


import yaml


from omegaconf import OmegaConf


import glob
import random



from base.base_loader import BaseDataset


from datasets.loader_utils import (
    multi_label_to_index,
    pad_video,
    video_transforms,
    sampling,

    VideoRandomResizedCrop,
    read_gsl_continuous,

    gsl_context,
    read_bounding_box,
)

In [None]:
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

In [None]:
# check if cuda is available
print(torch.cuda.is_available())

In [None]:
# uncomment and run this cell to use cpu
# torch.cuda.is_available = lambda : False


# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# check if cuda is available (should say False if the above cell is run)
print(torch.cuda.is_available())

In [None]:
# set random seed
pl.seed_everything(42)

In [None]:
# reference the vision transformer that will be used (86 million params)
HUGGING_MODEL_NAME = "google/vit-base-patch16-224-in21k"


# HUGGING_MODEL_NAME = 'google/vit-large-patch16-224-in21k'


# HUGGING_MODEL_NAME = 'google/vit-huge-patch14-224-in21k'



# BATCH_SIZE = 32


# NUM_WORKERS = 2


# MAX_EPOCHS = 10


# LEARNING_RATE = 3e-4


# SEQUENCE_LENGTH = 16

In [None]:
# set image processor
image_processor = ViTImageProcessor.from_pretrained(HUGGING_MODEL_NAME)

In [None]:
# too much work to remove args from the code, so just create a dummy class and pass it to the functions
class Args:
    """Lightweight stand-in for the argument object the dataset code expects.

    Attributes:
        cwd: working directory captured when the instance is created.
        input_data: value of ``config["dataset"]["input_data"]``.
        return_context: always ``False`` for this notebook.
    """

    def __init__(self, config):
        """Populate the attributes from ``config`` and the current process state."""
        self.cwd = os.getcwd()
        dataset_section = config["dataset"]
        self.input_data = dataset_section["input_data"]
        self.return_context = False

In [None]:
# set file paths for fine tuning
# Directory name for pre-extracted features (not referenced in the cells shown here).
feats_path = "gsl_cont_features/"


# Split identifiers. GSL_SI compares its `mode` argument against the
# train / validation / test values; `dev_prefix` is not referenced in the
# cells shown here.
train_prefix = "train"


dev_prefix = "dev"


validation_prefix = "validation"


test_prefix = "test"



# CSV annotation files for each split. GSL_SI joins these onto `args.cwd`,
# so they are relative to the working directory captured by `Args`.
train_filepath = "files/GSL_continuous/gsl_split_SI_train.csv"


dev_filepath = "files/GSL_continuous/gsl_split_SI_dev.csv"


# NOTE: validation points at the same CSV as the dev split.
validation_filepath = "files/GSL_continuous/gsl_split_SI_dev.csv"


test_filepath = "files/GSL_continuous/gsl_split_SI_test.csv"



# dataset class inherits from BaseDataset


class GSL_SI(BaseDataset):
    """GSL continuous (signer-independent split) dataset of paired RGB/depth clips.

    ``__getitem__`` returns ``(rgb_clip, depth_clip, label_indices)``.
    ``feature_loader`` and ``video_loader`` are alternative single-modality
    accessors; one of them is bound to ``self.get`` depending on
    ``config["dataset"]["modality"]``.
    """

    def __init__(self, config, args, mode, classes):
        """
        Read the annotation file for ``mode`` and cache the dataset settings.

        Args:
            config: full configuration mapping; only ``config["dataset"]`` is
                read by this class.
            args: object exposing ``cwd`` (base directory for the annotation
                files) and ``return_context``.
            mode: split name; must equal ``train_prefix``,
                ``validation_prefix`` or ``test_prefix``.
            classes: sequence of class names; the position of each name is
                used as its index in ``class_to_idx``.

        Raises:
            ValueError: if ``mode`` is not one of the supported splits.
        """
        super(GSL_SI, self).__init__(config, args, mode, classes)

        self.config = config["dataset"]
        self.mode = mode
        self.dim = tuple(self.config["dim"])
        self.classes = classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.num_classes = self.config["classes"]
        self.seq_length = self.config[self.mode]["seq_length"]
        self.normalize = self.config["normalize"]
        self.padding = self.config["padding"]
        self.augmentation = self.config[self.mode]["augmentation"]
        self.return_context = args.return_context

        self.rgb_data_path = os.path.join(
            self.config["input_data"], self.config["images_path"]
        )
        self.depth_data_path = os.path.join(
            self.config["input_data"], self.config["depth_path"]
        )

        # Resolve the annotation CSV for this split. Failing here with a clear
        # message beats the AttributeError on ``self.list_IDs`` that an
        # unrecognised mode would otherwise trigger a few lines later.
        split_files = {
            train_prefix: train_filepath,
            validation_prefix: validation_filepath,
            test_prefix: test_filepath,
        }
        if self.mode not in split_files:
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected one of {sorted(split_files)}"
            )
        self.list_IDs, self.list_glosses = read_gsl_continuous(
            os.path.join(args.cwd, split_files[self.mode])
        )

        print(f"{len(self.list_IDs)} {self.mode} instances")

        self.bbox = read_bounding_box(
            os.path.join(args.cwd, "files/GSL_continuous/bbox_for_gsl_continuous.txt")
        )

        self.context = gsl_context(self.list_IDs, self.list_glosses)

        # ``self.get`` / ``self.data_path`` are only defined for these two
        # modalities; any other value leaves them unset, as before.
        if self.config["modality"] == "full":
            self.data_path = os.path.join(
                self.config["input_data"], self.config["images_path"]
            )
            self.get = self.video_loader
        elif self.config["modality"] == "features":
            self.data_path = os.path.join(
                self.config["input_data"], self.config["features_path"]
            )
            self.get = self.feature_loader

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """
        Load the RGB clip, the matching depth clip and the label indices.

        The depth sample path is the RGB sample path with ``"color"`` replaced
        by ``"depth"``, resolved under ``self.depth_data_path``.

        Returns:
            Tuple ``(rgb_x, depth_x, y)``.
        """
        sample_id = self.list_IDs[index]

        rgb_x = self.load_video_sequence(path=sample_id, img_type="jpg")
        # is_depth=True makes the loader read from depth_data_path; without it
        # the depth clip was looked up under the RGB root.
        depth_x = self.load_video_sequence(
            path=sample_id.replace("color", "depth"), img_type="jpg", is_depth=True
        )
        y = multi_label_to_index(
            classes=self.classes, target_labels=self.list_glosses[index]
        )

        # Debug-level logging instead of print(): this runs once per sample,
        # possibly in several DataLoader workers.
        logging.debug("In __getitem__: RGB shape: %s", rgb_x.shape)
        logging.debug("In __getitem__: Depth shape: %s", depth_x.shape)
        logging.debug("In __getitem__: Label shape: %s", y.shape)

        return rgb_x, depth_x, y

    def feature_loader(self, index):
        """
        Load a pre-extracted ``.npy`` feature array and its label indices.

        Returns:
            ``(x, [y, c])`` when ``self.return_context`` is true, where ``c``
            holds the context label indices (or a single ``0`` when the sample
            has no context); otherwise ``(x, y)``.
        """
        folder_path = os.path.join(self.data_path, self.list_IDs[index])

        y = multi_label_to_index(
            classes=self.classes, target_labels=self.list_glosses[index]
        )

        if self.context[index] is not None:
            c = multi_label_to_index(
                classes=self.classes, target_labels=self.context[index]
            )
        else:
            c = torch.tensor([0], dtype=torch.int)

        x = torch.FloatTensor(np.load(folder_path + ".npy")).squeeze(0)

        if self.return_context:
            return x, [y, c]
        return x, y

    def video_loader(self, index):
        """Load the RGB clip and label indices for ``index`` (no depth)."""
        x = self.load_video_sequence(path=self.list_IDs[index], img_type="jpg")
        y = multi_label_to_index(
            classes=self.classes, target_labels=self.list_glosses[index]
        )
        return x, y

    def load_video_sequence(self, path, img_type="jpg", is_depth=False):
        """
        Read, sample, transform and pad the frames of one sample.

        Args:
            path: sample directory, relative to the modality root.
            img_type: filename suffix used in the frame glob.
            is_depth: read from ``self.depth_data_path`` instead of
                ``self.rgb_data_path``.

        Returns:
            Float tensor obtained by stacking the per-frame tensors produced
            by ``video_transforms`` and padding them with ``pad_video``.

        Raises:
            FileNotFoundError: if no frame matches the glob for this sample.

        NOTE(review): RGB and depth clips are loaded by separate calls, so with
        augmentation enabled each call draws its own random frame subset and
        its own crop/colour parameters - the two clips are not guaranteed to
        be frame-aligned.
        """
        data_path = self.depth_data_path if is_depth else self.rgb_data_path
        frame_dir = os.path.join(data_path, path)
        images = sorted(glob.glob(frame_dir + "/*" + img_type))

        if not images:
            # Previously this only printed a path (built from self.data_path,
            # which may be unset) and then failed inside torch.stack.
            raise FileNotFoundError(f"No '*{img_type}' frames found in {frame_dir}")

        if self.augmentation:
            # Training: randomly keep 80-99% of the frames (only when that
            # leaves more than 15), then randomly cap at seq_length.
            temporal_augmentation = int(
                (np.random.randint(80, 100) / 100.0) * len(images)
            )
            if temporal_augmentation > 15:
                images = sorted(random.sample(images, k=temporal_augmentation))
            if len(images) > self.seq_length:
                images = sorted(random.sample(images, k=self.seq_length))
        else:
            # Evaluation: non-random down-sampling via the project helper.
            if len(images) > self.seq_length:
                images = sorted(sampling(images, self.seq_length))

        # One set of augmentation parameters per clip so every frame of the
        # clip receives the same transform. The draws are made unconditionally
        # and in this order to keep the random stream unchanged.
        i = np.random.randint(0, 30)
        j = np.random.randint(0, 30)
        brightness = 1 + random.uniform(-0.2, +0.2)
        contrast = 1 + random.uniform(-0.2, +0.2)
        hue = random.uniform(0, 1) / 10.0
        r_resize = (224, 224)
        crop_or_bbox = random.uniform(0, 1) > 0.5
        # uniform(0, 1) never exceeds 1, so flipping is effectively disabled.
        to_flip = random.uniform(0, 1) > 1
        grayscale = random.uniform(0, 1) > 0.9
        t1 = VideoRandomResizedCrop(self.dim[0], scale=(0.9, 1.0), ratio=(0.8, 1.2))

        # Columns trimmed from each side of every frame.
        # NOTE(review): the 648 bound assumes 648-pixel-wide frames - confirm.
        crop_size = 120

        img_sequence = []
        for img_path in images:
            # Context manager closes the underlying file as soon as the
            # pixel data has been converted.
            with Image.open(img_path) as raw:
                frame_o = raw.convert("RGB")

            frame1 = np.array(frame_o)[:, crop_size : 648 - crop_size]
            frame = Image.fromarray(frame1)

            if self.augmentation:
                frame = frame.resize(r_resize)
                img_tensor = video_transforms(
                    img=frame,
                    i=i,
                    j=j,
                    bright=brightness,
                    cont=contrast,
                    h=hue,
                    dim=self.dim,
                    resized_crop=t1,
                    augmentation=True,
                    normalize=self.normalize,
                    crop=crop_or_bbox,
                    to_flip=to_flip,
                    grayscale=grayscale,
                )
            else:
                # RGB and depth frames are resized to the same target size.
                frame = frame.resize(self.dim)
                img_tensor = video_transforms(
                    img=frame,
                    i=i,
                    j=j,
                    bright=1,
                    cont=1,
                    h=0,
                    dim=self.dim,
                    augmentation=False,
                    normalize=self.normalize,
                )

            img_sequence.append(img_tensor)

        pad_len = self.seq_length - len(images)

        X1 = torch.stack(img_sequence).float()

        if self.padding:
            X1 = pad_video(X1, padding_size=pad_len, padding_type="zeros")

        # NOTE(review): this second padding is applied on top of the one
        # above whenever fewer than 25 frames were read - confirm the
        # combined length is what the model expects.
        if len(images) < 25:
            X1 = pad_video(X1, padding_size=25 - len(images), padding_type="zeros")

        return X1

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import ViTModel, ViTConfig
import torch.nn.functional as F


class GSLDualViT(pl.LightningModule):
    """Two-stream (RGB + depth) ViT encoder with an LSTM head for clip classification.

    Each frame is encoded by a ViT (CLS token), the RGB and depth embeddings
    are fused by a linear layer, the fused frame sequence is fed to an LSTM,
    and the last LSTM timestep is classified.
    """

    def __init__(self, num_classes, config, class_names=None):
        """
        Args:
            num_classes: size of the classifier output.
            config: configuration mapping; reads ``trainer.optimizer``,
                ``trainer.scheduler``, ``trainer.model.backbone.rnn`` and
                ``dataset.train.seq_length``.
            class_names: optional sequence of class names stored on the
                module. When omitted, the notebook-level ``class_names``
                global is used if it exists (the previous behaviour).
        """
        super().__init__()
        self.config = config
        if class_names is None:
            # Backward compatibility: earlier revisions read the global directly.
            class_names = globals().get("class_names")
        self.class_names = class_names
        # Keep the attribute in sync with the classifier's actual output size.
        self.num_classes = num_classes
        self.learning_rate = config["trainer"]["optimizer"]["lr"]
        self.sequence_length = config["dataset"]["train"]["seq_length"]

        # RGB stream: pretrained checkpoint, fully trainable.
        self.rgb_vit = ViTModel.from_pretrained(HUGGING_MODEL_NAME)
        self.rgb_vit.train()
        for param in self.rgb_vit.parameters():
            param.requires_grad = True

        # Depth stream: same architecture, built from the configuration only
        # (no checkpoint weights are loaded). The loader converts depth frames
        # to 3-channel images, hence num_channels = 3.
        depth_config = ViTConfig.from_pretrained(HUGGING_MODEL_NAME)
        depth_config.num_channels = 3
        self.depth_vit = ViTModel(depth_config)

        hidden_size = self.rgb_vit.config.hidden_size

        # Fusion layer: concatenated [rgb, depth] embedding -> hidden_size.
        self.fusion = nn.Linear(hidden_size * 2, hidden_size)

        # Temporal modelling over the fused per-frame embeddings.
        rnn_cfg = self.config["trainer"]["model"]["backbone"]["rnn"]
        self.lstm = nn.LSTM(
            hidden_size,
            rnn_cfg["hidden_size"],
            num_layers=rnn_cfg["num_layers"],
            bidirectional=rnn_cfg["bidirectional"],
            batch_first=True,
            dropout=rnn_cfg["dropout"],
        )

        # Classification head; a bidirectional LSTM doubles the feature size.
        lstm_output_size = rnn_cfg["hidden_size"]
        if rnn_cfg["bidirectional"]:
            lstm_output_size *= 2
        self.classifier = nn.Linear(lstm_output_size, num_classes)

    def forward(self, rgb_sequence, depth_sequence=None):
        """
        Compute class logits for a batch of clips.

        Args:
            rgb_sequence: 5-D tensor indexed as (batch, frame, channels,
                height, width).
            depth_sequence: optional tensor indexed the same way; when given
                it must contain at least as many frames as ``rgb_sequence``.

        Returns:
            Logits of shape (batch, num_classes), or ``None`` when the clip
            has zero frames.

        Raises:
            ValueError: if ``depth_sequence`` has fewer frames than
                ``rgb_sequence``.
        """
        _batch_size, seq_len, _channels, _height, _width = rgb_sequence.shape
        if seq_len == 0:
            return None
        if depth_sequence is not None and depth_sequence.shape[1] < seq_len:
            raise ValueError(
                f"depth_sequence has {depth_sequence.shape[1]} frames "
                f"but rgb_sequence has {seq_len}"
            )

        # Encoder errors are no longer caught per frame: swallowing them
        # silently shortened the sequence and hid the real failure.
        fused_features = []
        for t in range(seq_len):
            rgb_features = self.rgb_vit(rgb_sequence[:, t]).last_hidden_state[:, 0]
            if depth_sequence is not None:
                depth_features = self.depth_vit(
                    depth_sequence[:, t]
                ).last_hidden_state[:, 0]
                fused = self.fusion(torch.cat([rgb_features, depth_features], dim=-1))
            else:
                fused = rgb_features
            fused_features.append(fused)

        fused_sequence = torch.stack(fused_features, dim=1)
        lstm_out, _ = self.lstm(fused_sequence)
        # Classify from the last timestep. Returning logits (not the fused
        # features) keeps training and validation on the same code path.
        return self.classifier(lstm_out[:, -1])

    def _shared_step(self, batch, stage):
        """Run forward, loss and accuracy for one batch and log them as ``{stage}_*``."""
        rgb_sequence, depth_sequence, labels = batch

        logits = self(rgb_sequence, depth_sequence)
        if logits is None:
            # Zero-length clip: skip the batch, as before.
            return None

        # NOTE(review): cross_entropy on (batch, num_classes) logits expects
        # one class index per sample, whereas collate_fn in this notebook
        # emits labels padded with -1 to (batch, max_label_length) - confirm
        # which label format is intended.
        loss = nn.functional.cross_entropy(logits, labels)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        self.log(
            f"{stage}_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            f"{stage}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log ``train_loss`` / ``train_acc``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` and log ``val_loss`` / ``val_acc``."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Build Adam plus a ReduceLROnPlateau scheduler monitored on ``val_loss``."""
        optimizer_cfg = self.config["trainer"]["optimizer"]
        scheduler_cfg = self.config["trainer"]["scheduler"]

        optimizer = optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=optimizer_cfg["weight_decay"],
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=scheduler_cfg["scheduler_factor"],
            patience=scheduler_cfg["scheduler_patience"],
            min_lr=scheduler_cfg["scheduler_min_lr"],
            verbose=True,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }

In [None]:
def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return the loaded object.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding (matching ``load_classes``).
    """
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


config = load_config("configs/dummy.yaml")

In [None]:
def load_classes(file_path):
    """Return the first column of every non-empty row of a UTF-8 CSV file.

    Args:
        file_path: path of the CSV file listing one class per row.

    Returns:
        List of class names in file order.
    """
    # newline="" lets the csv module handle line endings itself, as its
    # documentation requires for file objects.
    with open(file_path, "r", encoding="utf-8", newline="") as f:
        # Blank lines are yielded as empty rows; skip them instead of
        # failing on row[0].
        return [row[0] for row in csv.reader(f) if row]


class_names = load_classes(config["dataset"]["classes_filepath"])

In [None]:
def collate_fn(batch):
    """Merge ``(rgb, depth, label)`` samples into padded batch tensors.

    RGB and depth clips are zero-padded along their first dimension to the
    longest clip in the batch. Each label is right-padded with ``-1`` to the
    longest label in the batch.

    Returns:
        Tuple ``(rgb_padded, depth_padded, labels)``.
    """
    rgb_clips, depth_clips, targets = zip(*batch)

    rgb_batch = pad_sequence(rgb_clips, batch_first=True)
    depth_batch = pad_sequence(depth_clips, batch_first=True)

    longest = max(len(target) for target in targets)
    label_rows = [
        torch.tensor(target.tolist() + [-1] * (longest - len(target)))
        for target in targets
    ]

    return rgb_batch, depth_batch, torch.stack(label_rows)

In [None]:
# Argument shim required by the dataset constructor (cwd, input_data, return_context).
args = Args(config)

# set up the datasets and dataloaders
# Both datasets share the same class list; `mode` selects the split CSV and
# the per-split section of config["dataset"].
train_dataset = GSL_SI(config, args, mode="train", classes=class_names)
val_dataset = GSL_SI(config, args, mode="validation", classes=class_names)

# Batch size, shuffling and worker count come from the per-split config
# sections; collate_fn pads clips and labels so samples can be batched.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["dataset"]["train"]["batch_size"],
    shuffle=config["dataset"]["train"]["shuffle"],
    num_workers=config["dataset"]["train"]["num_workers"],
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config["dataset"]["validation"]["batch_size"],
    shuffle=config["dataset"]["validation"]["shuffle"],
    num_workers=config["dataset"]["validation"]["num_workers"],
    collate_fn=collate_fn,
)

In [None]:
# initialise model
model = GSLDualViT(num_classes=config["dataset"]["classes"], config=config)

In [None]:
# trainer setup and initialization
trainer = pl.Trainer(
    max_epochs=config["trainer"]["epochs"],
    # accelerator='cpu' if config['trainer']['cuda'] else 'cpu', # uncomment this line to use CPU, and comment the next line
    # GPU when config["trainer"]["cuda"] is truthy, otherwise CPU.
    accelerator="gpu" if config["trainer"]["cuda"] else "cpu",
    devices=1,
    # Mixed 16-bit precision.
    precision="16-mixed",
    # Accumulate gradients over 4 batches before each optimizer step.
    accumulate_grad_batches=4,
    # Skip the validation sanity check that normally runs before training.
    num_sanity_val_steps=0,
    callbacks=[
        # Keep the 3 checkpoints with the lowest val_loss under ./checkpoints.
        ModelCheckpoint(
            dirpath="checkpoints",
            filename="gsl-dualvit-{epoch:02d}-{val_loss:.2f}",
            save_top_k=3,
            monitor="val_loss",
            mode="min",
        ),
        # Stop once val_loss has not improved for 5 validation checks.
        EarlyStopping(monitor="val_loss", patience=5, mode="min"),
    ],
    enable_progress_bar=True,
    log_every_n_steps=1,
    # TensorBoard logs go to logs/<config["trainer"]["logger"]>/.
    logger=TensorBoardLogger("logs/", name=config["trainer"]["logger"]),
)

In [None]:
# get available memory info
torch.cuda.mem_get_info()

In [None]:
# free up gpu memory before training
def free_gpu_cache():
    """Print GPU utilisation, release PyTorch's cached CUDA memory, print again.

    The previous implementation also reset the device through numba
    (``cuda.select_device(0); cuda.close()``). That tears down the CUDA
    context for the whole process, which PyTorch uses too, so it is unsafe to
    run right before training in the same kernel. Only PyTorch's own cache is
    released here.
    """
    import gc

    print("Initial GPU Usage")
    gpu_usage()

    # Collect unreachable Python objects first so any CUDA tensors they hold
    # are returned to the caching allocator before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

    print("GPU Usage after emptying the cache")
    gpu_usage()


free_gpu_cache()

In [None]:
# start training
trainer.fit(model, train_loader, val_loader)

In [None]:
# # Load the best model
# best_model_path = checkpoint_callback.best_model_path
# best_model = GSLDualViT.load_from_checkpoint(best_model_path)

In [None]:
# # Evaluate on test set
# test_dataset = GSLRGBDDataset(root_dir='path/to/GSL/dataset', split='test', transform=val_transform)
# test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# trainer.test(best_model, test_loader)

In [None]:
# # Function to perform predictions
# def perform_predictions(model, test_loader, num_samples=10):
#     model.eval()
#     all_preds = []
#     all_labels = []
#     rgb_images = []
#     depth_images = []

#     with torch.no_grad():
#         for batch in test_loader:
#             rgb_seq, depth_seq, y = batch
#             logits = model(rgb_seq, depth_seq)
#             preds = logits.argmax(dim=-1)
#             all_preds.extend(preds.cpu().numpy())
#             all_labels.extend(y.cpu().numpy())
#             rgb_images.extend(rgb_seq[:, 0].cpu().numpy())  # Take first frame of each sequence
#             depth_images.extend(depth_seq[:, 0].cpu().numpy())

#             if len(all_preds) >= num_samples:
#                 break

#     return rgb_images[:num_samples], depth_images[:num_samples], all_preds[:num_samples], all_labels[:num_samples]

In [None]:
# # Visualize predictions
# import matplotlib.pyplot as plt

# rgb_images, depth_images, preds, labels = perform_predictions(best_model, test_loader)

# fig, axes = plt.subplots(2, 5, figsize=(20, 8))
# axes = axes.ravel()

# for i, (rgb_img, depth_img, pred, label) in enumerate(zip(rgb_images, depth_images, preds, labels)):
#     rgb_img = np.transpose(rgb_img, (1, 2, 0))
#     rgb_img = (rgb_img * image_processor.image_std) + image_processor.image_mean
#     rgb_img = np.clip(rgb_img, 0, 1)

#     depth_img = np.squeeze(depth_img)

#     axes[i].imshow(rgb_img)
#     axes[i].imshow(depth_img, cmap='gray', alpha=0.5)
#     axes[i].set_title(f'Pred: {test_dataset.classes[pred]}\nTrue: {test_dataset.classes[label]}')
#     axes[i].axis('off')

# plt.tight_layout()
# plt.show()