# FIT3164 - MDS08


## Setup


In [None]:
# import os
# from google.colab import drive

# drive.mount("/content/drive")

### Install Packages

In [None]:
!pip cache purge

In [None]:
!pip install torch

In [None]:
!pip install pytorch-lightning omegaconf torchinfo

In [None]:
!pip install albumentations

In [None]:
!pip install transformers

## Fine-tuning


### Imports

In [3]:
!pip install torchvision



In [4]:
import os

In [5]:
!pip install torchmetrics



In [6]:
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from torchvision import transforms
from torchmetrics.classification import MultilabelAccuracy, MultilabelF1Score, MultilabelPrecision, MultilabelRecall

In [7]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ModelSummary
from pytorch_lightning.loggers import TensorBoardLogger

In [8]:
from transformers import VivitModel, VivitImageProcessor, ViTImageProcessor, ViTForImageClassification, ViTModel, ViTConfig

In [9]:
from omegaconf import OmegaConf

In [10]:
import yaml

In [11]:
import csv

In [12]:
import glob

In [13]:
import numpy as np

In [14]:
import time

In [15]:
import random

In [16]:
import cv2

In [17]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

### Other

In [18]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [19]:
# Seed the global RNGs. workers=True additionally asks Lightning to derive a
# distinct, reproducible seed for every DataLoader worker process, so random
# augmentation executed inside workers is repeatable between runs.
pl.seed_everything(42, workers=True)

Seed set to 42


42

In [20]:
# Hugging Face Hub identifier of the pretrained ViViT checkpoint used by the
# image processor and the model below.
HUGGING_MODEL_NAME = "google/vivit-b-16x2"

### Config, Classes, Collate

In [21]:
def load_config(config_path):
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed document exactly as produced by ``yaml.safe_load``.
    """
    # Read as UTF-8 explicitly (like every other file reader in this notebook)
    # so parsing does not depend on the platform's default locale encoding.
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)

In [22]:
# Raw string: the Windows path contains backslashes that would otherwise be
# parsed as (invalid) escape sequences. The resulting path value is unchanged.
config = load_config(r"D:\GitHub\FIT3164\model\dummy.yaml")

In [23]:
def load_classes(file_path):
    """Read class names from the first column of a UTF-8 CSV file.

    Args:
        file_path: Path of the CSV file; only the first field of each row is used.

    Returns:
        A list with the first field of every non-empty row, in file order.
    """
    # newline="" lets the csv module do its own newline handling.
    with open(file_path, "r", encoding="utf-8", newline="") as f:
        # csv.reader yields [] for blank lines; skip them instead of failing on row[0].
        return [row[0] for row in csv.reader(f) if row]

In [24]:
# Class vocabulary read from the file named in the dataset section of the config.
class_names = load_classes(file_path=config["dataset"]["classes_filepath"])

In [25]:
def collate_fn(batch):
    """Merge a list of ``(input, target)`` samples into two batch tensors.

    Args:
        batch: Sequence of 2-tuples of tensors; all inputs must share one
            shape and all targets must share one shape so they can be stacked.

    Returns:
        A tuple ``(inputs, targets)`` where each element is the samples'
        tensors stacked along a new leading dimension.
    """
    clips, targets = zip(*batch)
    return torch.stack(clips), torch.stack(targets)

### Dataset

#### Helper Functions

In [26]:
def multi_label_to_index(classes, target_labels):
    """Translate a whitespace-separated label string into class indices.

    Args:
        classes: Sequence of class names; a name's position is its index.
        target_labels: String of labels separated by whitespace.

    Returns:
        An integer tensor holding, in order of appearance, the index of every
        label that occurs in ``classes``. Unknown labels are dropped.
    """
    lookup = {}
    for position, word in enumerate(classes):
        lookup[word] = position

    tokens = target_labels.strip().split()
    known = [lookup[token] for token in tokens if token in lookup]
    return torch.tensor(known, dtype=torch.int)

In [27]:
def pad_video(x, target_length):
    """Truncate or zero-pad a tensor along its first dimension.

    Args:
        x: Tensor whose first dimension is the one to adjust (for a video,
            presumably the frame axis - confirm against callers).
        target_length: Desired size of the first dimension.

    Returns:
        ``x[:target_length]`` when ``x`` is already long enough; otherwise
        ``x`` with zeros appended at the end of dimension 0.
    """
    current_length = x.size(0)
    if current_length >= target_length:
        return x[:target_length]

    padding_size = target_length - current_length
    # The pad spec is ordered from the last dimension backwards: leave every
    # trailing dimension untouched and append `padding_size` entries to dim 0.
    # Building it from x.dim() supports inputs of any rank, not only 4-D.
    pad_spec = (0, 0) * (x.dim() - 1) + (0, padding_size)
    # Referenced through `torch` because no `F` alias is imported in this notebook.
    return torch.nn.functional.pad(x, pad_spec)

#### DataLoader

In [28]:
class GSL_SI_ViViT(Dataset):
    """Dataset of frame-folder videos with multi-hot gloss targets.

    Each sample is a ``(frames, labels)`` pair: ``frames`` is a float tensor of
    ``seq_length`` stacked frames and ``labels`` is a float vector of length
    ``num_classes`` with 1.0 at every class that occurs in the sample's gloss
    string.

    Args:
        config: Full configuration mapping; only ``config['dataset']`` is used.
        mode: Key of the split (e.g. ``'train'``); selects
            ``config['dataset'][mode]`` and ``config['dataset'][f'{mode}_filepath']``.
        classes_path: Text file with one class name per line.
        processor: Stored on the instance as ``self.processor``; not used by
            any method of this class.
    """

    def __init__(self, config, mode, classes_path, processor):
        self.config = config['dataset']
        self.mode = mode
        self.seq_length = self.config[mode]['seq_length']
        self.augmentation = self.config[mode]['augmentation']
        self.data_path = os.path.join(self.config['input_data'], self.config['images_path'])
        self.to_tensor = ToTensorV2()
        self.processor = processor

        self.indices, self.classes, self.id2w = self.read_gsl_continuous_classes(classes_path)
        self.num_classes = len(self.classes)
        # Built once here instead of once per sample in get_labels().
        self._class_to_index = {word: i for i, word in enumerate(self.classes)}

        filepath = self.config[f'{mode}_filepath']
        self.list_IDs, self.list_glosses = self.read_gsl_continuous(filepath)
        print(f"{len(self.list_IDs)} {self.mode} instances")

        self.bbox_data = self.bounding_box_handler(self.config['bbox_filepath'])

        self.transform = self.get_transforms()

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return ``(frames, labels)`` for ``index``, skipping unreadable samples.

        If the requested video has no frames on disk or none of its frames can
        be decoded, the following samples are tried in order (wrapping around).

        Raises:
            IndexError: If ``index`` is outside the dataset.
            RuntimeError: If every sample in the dataset was tried and none
                could be loaded (previously this recursed without bound).
        """
        candidate = index
        for _ in range(max(len(self), 1)):
            # Direct indexing keeps the IndexError for out-of-range indices.
            video_path = self.list_IDs[candidate]
            try:
                frames = self.load_video_sequence(video_path)
            except FileNotFoundError as e:
                print(str(e))
                frames = None

            if frames is not None and len(frames) == 0:
                print(f"Skipping index {candidate} due to empty video sequence")
                frames = None

            if frames is not None:
                processed_frames = self.process_frames(frames)
                labels = self.get_labels(self.list_glosses[candidate])
                return processed_frames, labels

            candidate = (candidate + 1) % len(self)

        raise RuntimeError(f"No loadable {self.mode} sample found (started at index {index})")

    def read_gsl_continuous_classes(self, path):
        """Read the class list and prepend the ``'blank'`` class at index 0.

        Returns:
            ``(indices, classes, id2w)``: the list of integer indices, the list
            of class names, and a dict mapping index to class name.
        """
        with open(path, 'r', encoding='utf-8') as file:
            classes = ['blank'] + file.read().splitlines()

        indices = list(range(len(classes)))
        id2w = dict(zip(indices, classes))

        return indices, classes, id2w

    def read_gsl_continuous(self, csv_path):
        """Parse a split file whose lines look like ``<path>|<glosses>``.

        Lines without a ``|`` are reported and skipped.

        Returns:
            ``(paths, glosses_list)``: two parallel lists of strings.
        """
        paths = []
        glosses_list = []

        with open(csv_path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if '|' not in line:
                    print(f"Skipping invalid line: {line}")
                    continue

                path, glosses = line.split('|', 1)
                paths.append(path)
                glosses_list.append(glosses)

        return paths, glosses_list

    def bounding_box_handler(self, path, video_path=None):
        """Load (once) and query the bounding-box table.

        The file at ``path`` holds lines of the form ``<key>|<name>:<int>,...``.
        On the first call the file is parsed into ``self.bbox_data_cache``;
        each entry is registered under the key itself, its basename, and the
        key with ``/`` replaced by ``_``.

        Args:
            path: Bounding-box file; only read on the first call.
            video_path: Optional video identifier to look up.

        Returns:
            The coordinate dict for ``video_path`` (``None`` if no key variant
            matches), or the whole cache when ``video_path`` is not given.
        """
        if not hasattr(self, 'bbox_data_cache'):
            self.bbox_data_cache = {}
            with open(path, 'r', encoding='utf-8') as file:
                for line in file:
                    parts = line.strip().split('|')
                    if len(parts) != 2:
                        print(f"Invalid line: {line.strip()}")
                        continue

                    key, coordinates = parts
                    coords = {k: int(v) for k, v in (coord.strip().split(':') for coord in coordinates.split(','))}

                    self.bbox_data_cache[key] = coords
                    self.bbox_data_cache[os.path.basename(key)] = coords
                    self.bbox_data_cache[key.replace('/', '_')] = coords

        if video_path:
            possible_keys = [
                video_path,
                os.path.basename(video_path),
                video_path.replace('/', '_'),
                '_'.join(video_path.split('/')[-2:])
            ]

            for key in possible_keys:
                if key in self.bbox_data_cache:
                    return self.bbox_data_cache[key]

            return None

        return self.bbox_data_cache

    def get_transforms(self):
        """Build the per-frame albumentations pipeline for this split.

        Both branches end with the same ``Normalize`` so that evaluation frames
        are on the same value scale as training frames (the evaluation branch
        previously only resized, leaving raw pixel values).
        """
        normalize = A.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])
        if self.augmentation:
            # NOTE(review): this pipeline is invoked once per frame, so the
            # random crop/flip/jitter are presumably drawn independently for
            # every frame of a clip - confirm that this is intended.
            return A.Compose([
                A.Resize(256, 256),
                A.RandomCrop(224, 224),
                A.HorizontalFlip(p=0.5),
                A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                normalize,
            ])
        return A.Compose([A.Resize(224, 224), normalize])

    def get_labels(self, gloss_sequence):
        """Convert a whitespace-separated gloss string to a multi-hot vector.

        Glosses that are not in ``self.classes`` are ignored.

        Returns:
            A float32 tensor of length ``self.num_classes``.
        """
        class_to_index = getattr(self, '_class_to_index', None)
        if class_to_index is None:
            # Instance was created without __init__; build the lookup lazily.
            class_to_index = {word: i for i, word in enumerate(self.classes)}
            self._class_to_index = class_to_index

        indexes = [class_to_index[word] for word in gloss_sequence.strip().split() if word in class_to_index]
        labels_tensor = torch.zeros(self.num_classes, dtype=torch.float32)
        if indexes:
            labels_tensor[indexes] = 1.0
        return labels_tensor

    def load_video_sequence(self, path, img_type="jpg"):
        """Load ``seq_length`` evenly spaced, transformed frames of one video.

        Args:
            path: Video folder relative to ``self.data_path``.
            img_type: File extension of the frame images.

        Returns:
            A list of frame tensors. It can be shorter than ``seq_length``
            because frames that fail to decode are skipped with a warning.

        Raises:
            FileNotFoundError: If the folder contains no matching image.
        """
        image_dir = os.path.join(self.data_path, path)
        images = sorted(glob.glob(os.path.join(image_dir, f'*.{img_type}')))
        total_frames = len(images)

        if total_frames == 0:
            raise FileNotFoundError(f"No frames found for video {path}")

        indices = np.linspace(0, total_frames - 1, self.seq_length, dtype=int)
        selected_images = [images[i] for i in indices]

        # The bounding box depends only on the video, so look it up once
        # rather than once per frame.
        bbox_coords = self.bounding_box_handler(self.config['bbox_filepath'], path)

        frames = []
        for img_path in selected_images:
            frame = cv2.imread(img_path)
            if frame is None:
                print(f"Warning: Failed to read frame {img_path}")
                continue
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            if bbox_coords:
                x1, y1, x2, y2 = bbox_coords['x1'], bbox_coords['y1'], bbox_coords['x2'], bbox_coords['y2']
                frame = frame[y1:y2, x1:x2]

            frame = self.transform(image=frame)['image']
            frame = self.to_tensor(image=frame)['image']
            frames.append(frame)

        return frames

    def process_frames(self, frames):
        """Stack frames into one float tensor of exactly ``seq_length`` frames.

        Short lists are padded by repeating the last frame; long lists are
        truncated. ``frames`` must not be empty.
        """
        if len(frames) < self.seq_length:
            frames = frames + [frames[-1]] * (self.seq_length - len(frames))
        elif len(frames) > self.seq_length:
            frames = frames[:self.seq_length]

        frames_tensor = torch.stack(frames).float()
        return frames_tensor

### Initialise DataLoader

In [29]:
# Absolute path of the class-name file handed to the dataset constructors below.
classes_path = "D:\\GitHub\\FIT3164\\model\\files\\GSL_continuous\\classes.csv"
# Image processor that ships with the pretrained ViViT checkpoint.
processor = VivitImageProcessor.from_pretrained(HUGGING_MODEL_NAME)

In [30]:
# Training split; the constructor prints how many instances it found.
train_dataset = GSL_SI_ViViT(
    config=config,
    mode='train',
    classes_path=classes_path,
    processor=processor,
)

8821 train instances


In [31]:
# Batch size, shuffling and worker count all come from the 'train' section of
# the dataset config.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["dataset"]["train"]["batch_size"],
    shuffle=config["dataset"]["train"]["shuffle"],
    num_workers=config["dataset"]["train"]["num_workers"],
    collate_fn=collate_fn,
    # DataLoader rejects persistent_workers=True when num_workers is 0, so
    # only enable it when worker processes are actually configured.
    persistent_workers=config["dataset"]["train"]["num_workers"] > 0,
    pin_memory=True,
)

In [32]:
# Validation split; the constructor prints how many instances it found.
val_dataset = GSL_SI_ViViT(
    config=config,
    mode='validation',
    classes_path=classes_path,
    processor=processor,
)

588 validation instances


In [33]:
# Batch size, shuffling and worker count all come from the 'validation'
# section of the dataset config.
val_loader = DataLoader(
    val_dataset,
    batch_size=config["dataset"]["validation"]["batch_size"],
    shuffle=config["dataset"]["validation"]["shuffle"],
    num_workers=config["dataset"]["validation"]["num_workers"],
    collate_fn=collate_fn,
    # DataLoader rejects persistent_workers=True when num_workers is 0, so
    # only enable it when worker processes are actually configured.
    persistent_workers=config["dataset"]["validation"]["num_workers"] > 0,
    pin_memory=True,
)

In [34]:
# Size of the classifier output, taken from the training dataset's class count.
model_num_classes = train_dataset.num_classes

### Model

#### ViViT

In [35]:
class GSLViViT(pl.LightningModule):
    """ViViT backbone with a linear multi-label classification head.

    The pretrained ``VivitModel`` produces a pooled representation which a
    single ``nn.Linear`` maps to ``num_classes`` logits. Training uses
    ``BCEWithLogitsLoss``; accuracy, F1, precision and recall are logged.

    Args:
        num_classes: Number of output labels.
        config: Configuration mapping; ``config["trainer"]["optimizer"]`` and
            ``config["trainer"]["scheduler"]`` are read in
            ``configure_optimizers``.
    """

    def __init__(self, num_classes, config):
        super(GSLViViT, self).__init__()
        self.save_hyperparameters()
        self.config = config
        self.num_classes = num_classes

        # NOTE(review): both dropout probabilities are set to 0.8, which is
        # unusually high - confirm that this is a deliberate choice.
        self.vivit = VivitModel.from_pretrained(
            HUGGING_MODEL_NAME,
            hidden_dropout_prob=0.8,
            attention_probs_dropout_prob=0.8
        )

        self.vivit.config.patch_size = self.vivit.config.tubelet_size[1]
        self.vivit.config.num_labels = self.num_classes

        if hasattr(self.vivit, 'classifier'):
            self.vivit.classifier = nn.Identity()

        self.classifier = nn.Linear(self.vivit.config.hidden_size, self.num_classes)

        self.loss_fn = nn.BCEWithLogitsLoss()

        self.train_accuracy = MultilabelAccuracy(num_labels=num_classes)
        self.val_accuracy = MultilabelAccuracy(num_labels=num_classes)
        self.test_accuracy = MultilabelAccuracy(num_labels=num_classes)
        # These three metric objects are used by both validation_step and test_step.
        self.f1_score = MultilabelF1Score(num_labels=num_classes)
        self.precision = MultilabelPrecision(num_labels=num_classes)
        self.recall = MultilabelRecall(num_labels=num_classes)

        self.image_processor = VivitImageProcessor.from_pretrained(HUGGING_MODEL_NAME)

        self.freeze()

    def freeze(self):
        """Select which backbone parameters are trained.

        Parameters whose name contains ``layernorm``, ``attention`` or
        ``pooler`` stay trainable; every other backbone parameter is frozen.
        The pooler is included because the checkpoint does not provide its
        weights (they are reported as newly initialised when loading) while
        ``forward`` classifies on ``pooler_output``; leaving it frozen would
        feed the head through a fixed random projection.

        NOTE(review): this name presumably shadows a built-in LightningModule
        method of the same name that has different semantics - confirm.
        """
        trainable_markers = ('layernorm', 'attention', 'pooler')
        for name, param in self.vivit.named_parameters():
            param.requires_grad = any(marker in name for marker in trainable_markers)

    def forward(self, pixel_values):
        """Return the classification logits for a batch of clips."""
        outputs = self.vivit(pixel_values=pixel_values, interpolate_pos_encoding=True)
        logits = self.classifier(outputs.pooler_output)
        return logits

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy; return the loss."""
        pixel_values, labels = self.process_batch(batch)
        logits = self(pixel_values)
        loss = self.loss_fn(logits, labels)

        preds = torch.sigmoid(logits)
        acc = self.train_accuracy(preds, labels.int())
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss, accuracy, F1, precision and recall."""
        pixel_values, labels = self.process_batch(batch)
        logits = self(pixel_values)
        loss = self.loss_fn(logits, labels)

        preds = torch.sigmoid(logits)
        acc = self.val_accuracy(preds, labels.int())
        f1 = self.f1_score(preds, labels.int())
        precision = self.precision(preds, labels.int())
        recall = self.recall(preds, labels.int())

        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_f1', f1, on_step=True, on_epoch=True)
        self.log('val_precision', precision, on_step=True, on_epoch=True)
        self.log('val_recall', recall, on_step=True, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log test loss, accuracy, F1, precision and recall."""
        pixel_values, labels = self.process_batch(batch)
        logits = self(pixel_values)
        loss = self.loss_fn(logits, labels)

        preds = torch.sigmoid(logits)
        acc = self.test_accuracy(preds, labels.int())
        f1 = self.f1_score(preds, labels.int())
        precision = self.precision(preds, labels.int())
        recall = self.recall(preds, labels.int())

        self.log('test_loss', loss, on_step=True, on_epoch=True)
        self.log('test_acc', acc, on_step=True, on_epoch=True)
        self.log('test_f1', f1, on_step=True, on_epoch=True)
        self.log('test_precision', precision, on_step=True, on_epoch=True)
        self.log('test_recall', recall, on_step=True, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Create AdamW plus a per-step OneCycleLR schedule from the config."""
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.config["trainer"]["optimizer"]["lr"],
            weight_decay=self.config["trainer"]["optimizer"]["weight_decay"],
        )

        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.config["trainer"]["scheduler"]["max_lr"],
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=self.config["trainer"]["scheduler"]["pct_start"],
            anneal_strategy=self.config["trainer"]["scheduler"]["anneal_strategy"],
            cycle_momentum=self.config["trainer"]["scheduler"]["cycle_momentum"],
            div_factor=self.config["trainer"]["scheduler"]["div_factor"],
            final_div_factor=self.config["trainer"]["scheduler"]["final_div_factor"],
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Step the schedule after every optimiser step, not every epoch.
                "interval": "step",
            },
        }

    def process_batch(self, batch):
        """Unpack ``(inputs, labels)`` and move both to the module's device."""
        x, y = batch
        pixel_values = x.to(self.device)
        y = y.to(self.device)
        return pixel_values, y

### Setup Model

In [36]:
# Instantiate the network with the dataset's class count and the loaded config.
model = GSLViViT(model_num_classes, config)

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2 and are newly initialized: ['vivit.pooler.dense.bias', 'vivit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [37]:
# Raw string: the Windows path contains backslashes that would otherwise be
# parsed as (invalid) escape sequences. The resulting path value is unchanged.
checkpoint_path = r'D:\GitHub\FIT3164\model\last-v1.ckpt'

In [38]:
# Replace the freshly constructed model with the one restored from the checkpoint.
model = GSLViViT.load_from_checkpoint(checkpoint_path)
# Switch to evaluation mode; as the cell's last expression, the returned module
# is what the notebook displays below.
model.eval()

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2 and are newly initialized: ['vivit.pooler.dense.bias', 'vivit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


GSLViViT(
  (vivit): VivitModel(
    (embeddings): VivitEmbeddings(
      (patch_embeddings): VivitTubeletEmbeddings(
        (projection): Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))
      )
      (dropout): Dropout(p=0.8, inplace=False)
    )
    (encoder): VivitEncoder(
      (layer): ModuleList(
        (0-11): 12 x VivitLayer(
          (attention): VivitAttention(
            (attention): VivitSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.8, inplace=False)
            )
            (output): VivitSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.8, inplace=False)
            )
          )
          (intermediate): VivitIntermediate(
            (dense): Linear(in

In [39]:
# Same device selection as in the "Other" section above, repeated here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model's parameters to that device; the returned module is displayed.
model.to(device)

GSLViViT(
  (vivit): VivitModel(
    (embeddings): VivitEmbeddings(
      (patch_embeddings): VivitTubeletEmbeddings(
        (projection): Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))
      )
      (dropout): Dropout(p=0.8, inplace=False)
    )
    (encoder): VivitEncoder(
      (layer): ModuleList(
        (0-11): 12 x VivitLayer(
          (attention): VivitAttention(
            (attention): VivitSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.8, inplace=False)
            )
            (output): VivitSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.8, inplace=False)
            )
          )
          (intermediate): VivitIntermediate(
            (dense): Linear(in

In [40]:
from pytorch_lightning.utilities.model_summary import summarize

In [41]:
# max_depth=-1 requests a summary that descends into every nested submodule.
model_summary = summarize(model, max_depth=-1)
print(model_summary)

    | Name                                                    | Type                   | Params | Mode
------------------------------------------------------------------------------------------------------------
0   | vivit                                                   | VivitModel             | 89.2 M | eval
1   | vivit.embeddings                                        | VivitEmbeddings        | 3.6 M  | eval
2   | vivit.embeddings.patch_embeddings                       | VivitTubeletEmbeddings | 1.2 M  | eval
3   | vivit.embeddings.patch_embeddings.projection            | Conv3d                 | 1.2 M  | eval
4   | vivit.embeddings.dropout                                | Dropout                | 0      | eval
5   | vivit.encoder                                           | VivitEncoder           | 85.1 M | eval
6   | vivit.encoder.layer                                     | ModuleList             | 85.1 M | eval
7   | vivit.encoder.layer.0                                   | Viv

In [56]:
import torch
import cv2
import numpy as np
import time
from collections import deque
import yaml

def load_config(config_path):
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed document exactly as produced by ``yaml.safe_load``.
    """
    # Read as UTF-8 explicitly (like the other file readers in this notebook)
    # so parsing does not depend on the platform's default locale encoding.
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)

# Load the config and checkpoint
# Configuration file and trained checkpoint consumed by the live demo.
config = load_config(r"D:\GitHub\FIT3164\model\configs\dummy.yaml")
checkpoint_path = r'D:\GitHub\FIT3164\model\last-v1.ckpt'

# Restore the trained network and switch it to evaluation mode.
model = GSLViViT.load_from_checkpoint(checkpoint_path)
model.eval()

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Preprocessing constants used by the helpers below.
SEQUENCE_LENGTH = 32  # frames per clip; presumably mirrors the model config - confirm
FRAME_SIZE = 224
PATCH_SIZE = 16

# Index -> class-name lookup, taken from the dataset's class list.
classes_path = config['dataset']['classes_filepath']
dataset = GSL_SI_ViViT(config, 'train', classes_path, None)
id2label = dict(enumerate(dataset.classes))

def preprocess_frame(frame):
    """Resize, colour-convert and normalise one captured BGR frame.

    The frame is resized to ``FRAME_SIZE`` x ``FRAME_SIZE``, converted from
    BGR to RGB, scaled to the 0-1 range and then standardised with mean 0.45
    and standard deviation 0.225 on every channel.

    Returns:
        The processed frame as a float array.
    """
    resized = cv2.resize(frame, (FRAME_SIZE, FRAME_SIZE))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    scaled = rgb.astype(np.float32) / 255.0
    return (scaled - 0.45) / 0.225

def process_sequence(frames):
    """Turn a list of preprocessed frames into one model-ready batch tensor.

    The list is resampled to exactly ``SEQUENCE_LENGTH`` frames with evenly
    spaced indices (``np.linspace`` with integer truncation), the same scheme
    the dataset's ``load_video_sequence`` uses for training clips. Shorter
    recordings therefore repeat frames evenly and longer ones are subsampled
    across their whole duration.

    Args:
        frames: Non-empty list of equally shaped channel-last frame arrays.

    Returns:
        A tensor of shape ``(1, SEQUENCE_LENGTH, channels, height, width)``
        on ``device``.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("process_sequence() needs at least one frame")

    if len(frames) != SEQUENCE_LENGTH:
        positions = np.linspace(0, len(frames) - 1, SEQUENCE_LENGTH, dtype=int)
        frames = [frames[i] for i in positions]

    clip = np.stack(frames)
    # Channel-last frames -> (num_frames, channels, height, width).
    clip = clip.transpose(0, 3, 1, 2)
    # Add the leading batch dimension.
    clip = torch.from_numpy(clip).unsqueeze(0)

    return clip.to(device)

# Initialize video capture
# Index of the capture device handed to OpenCV.
CAMERA_INDEX = 2
# Number of highest-probability classes reported per recording.
TOP_K = 3
# Upper bound on buffered frames, so a forgotten recording cannot grow without
# limit; once reached, the oldest frames are dropped.
MAX_BUFFERED_FRAMES = SEQUENCE_LENGTH * 10

# Initialize video capture
cap = cv2.VideoCapture(CAMERA_INDEX)
if not cap.isOpened():
    # cap.read() below will fail and end the loop; say why instead of exiting silently.
    print(f"Could not open camera index {CAMERA_INDEX}")

frame_buffer = deque(maxlen=MAX_BUFFERED_FRAMES)
recording = False
predicted_class = "No prediction yet"

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Draw the latest prediction and the key help on a copy, so the frame
        # that gets recorded stays free of overlay text.
        display_frame = frame.copy()
        cv2.putText(display_frame, f"Predicted: {predicted_class}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.putText(display_frame, "Press 'r' to start/stop recording, 'q' to quit",
                    (10, display_frame.shape[0] - 10), cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (255, 255, 255), 1)

        cv2.imshow('ViViT GSL Recognition', display_frame)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('r'):
            recording = not recording
            if recording:
                print("Recording started")
                frame_buffer.clear()
            else:
                print("Recording stopped, processing...")
                if len(frame_buffer) > 0:
                    recorded = list(frame_buffer)
                    if len(recorded) > SEQUENCE_LENGTH:
                        # Subsample evenly across the whole recording instead
                        # of keeping only one end of it.
                        positions = np.linspace(0, len(recorded) - 1, SEQUENCE_LENGTH, dtype=int)
                        recorded = [recorded[i] for i in positions]

                    input_tensor = process_sequence(recorded)
                    print(f"Input tensor shape: {input_tensor.shape}")

                    try:
                        with torch.no_grad():
                            outputs = model(input_tensor)
                            probabilities = torch.sigmoid(outputs)

                        # Never request more entries than there are classes.
                        k = min(TOP_K, probabilities.size(1))
                        top_probs, top_classes = torch.topk(probabilities, k=k, dim=1)
                        predicted_classes = []
                        for i in range(k):
                            class_name = id2label[top_classes[0][i].item()]
                            prob = top_probs[0][i].item()
                            predicted_classes.append(f"{class_name} ({prob:.2f})")
                        predicted_class = " | ".join(predicted_classes)
                        print(f"Predicted Sentence:")
                        for class_prob in predicted_classes:
                            print(class_prob)
                        print(f"\nPredicted classes: {predicted_class}")
                    except Exception as e:
                        # Keep the capture loop alive and report what went wrong.
                        print(f"Error during prediction: {str(e)}")
                        print(f"Model input shape: {input_tensor.shape}")
                else:
                    print("No frames recorded")

        if recording:
            processed_frame = preprocess_frame(frame)
            frame_buffer.append(processed_frame)

        time.sleep(0.01)  # Small delay to control frame rate

except KeyboardInterrupt:
    print("Interrupted by user")

finally:
    # Always release the camera and close the preview window.
    cap.release()
    cv2.destroyAllWindows()
    print("Video stream ended")

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2 and are newly initialized: ['vivit.pooler.dense.bias', 'vivit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


8821 train instances
Recording started
Recording stopped, processing...
Input tensor shape: torch.Size([1, 32, 3, 224, 224])
Predicted Sentence:
ΕΓΩ(1) (0.26)
ΕΣΥ (0.24)
ΚΑΛΟ (0.11)

Predicted classes: ΕΓΩ(1) (0.26) | ΕΣΥ (0.24) | ΚΑΛΟ (0.11)
Video stream ended
