# FIT3164 - MDS08


## Setup


In [None]:
import os
from google.colab import drive

# Mount Google Drive so the shared project folders under /content/drive are reachable.
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
# Clear pip's download cache on the Colab VM before the installs below.
!pip cache purge

In [None]:
# uncomment if using cpu training
# !pip install torch torchvision torchaudio

In [None]:
# GPU build of PyTorch from the CUDA 12.4 wheel index, plus training/logging dependencies.
# NOTE(review): `torchinfo` (imported in the model-summary cell) is not installed here.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install pytorch-lightning tensorboardX omegaconf transformers

In [None]:
# Work from the shared project folder so relative paths (configs/, files/, logs/) resolve.
# NOTE(review): this path is relative to the current directory (presumably /content on Colab).
os.chdir("drive/Shareddrives/FIT3164/model")
# 755 = rwxr-xr-x on the project directories and notebook.
!chmod -R 755 base configs files FIT3164_GSL.ipynb logger logs models trainer utils

## Fine-tuning


In [None]:
# all imports
# from GPUtil import showUtilization as gpu_usage
from numba import cuda


import cv2
import logging


import numpy as np
import torch
from torch import nn, optim


from torch.utils.data import Dataset, DataLoader


from torchvision import transforms


from sklearn.model_selection import train_test_split


from PIL import Image
import yaml
import csv
from torch.nn.utils.rnn import pad_sequence


import pytorch_lightning as pl


from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


from pytorch_lightning.loggers import TensorBoardLogger


from transformers import ViTImageProcessor, ViTModel, ViTConfig


import tensorboard


import tensorboardX


import yaml


from omegaconf import OmegaConf


import glob
import random


from base.base_loader import BaseDataset


from datasets.loader_utils import (
    multi_label_to_index,
    pad_video,
    video_transforms,
    sampling,
    VideoRandomResizedCrop,
    read_gsl_continuous,
    gsl_context,
    read_bounding_box,
)

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import ViTModel, ViTConfig
import torch.nn.functional as F

import gc

import psutil
import humanize
import os
# import GPUtil as GPU

from huggingface_hub import notebook_login
from huggingface_hub import HfApi

In [None]:
# NOTE(review): TF_FORCE_GPU_ALLOW_GROWTH is a TensorFlow setting; this notebook trains
# with PyTorch, so this likely has no effect - confirm before relying on it.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

In [None]:
# check if cuda is available
print(torch.cuda.is_available())

# uncomment and run this cell to use cpu
# torch.cuda.is_available = lambda : False
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# check if cuda is available (should say False if the above cell is run)
print(torch.cuda.is_available())

In [None]:
# Seed Python, NumPy and torch RNGs through Lightning for reproducible runs.
pl.seed_everything(42)

In [None]:
# Hugging Face id of the pretrained ViT backbone; larger alternatives are kept commented out.
HUGGING_MODEL_NAME = "google/vit-base-patch16-224-in21k"
# HUGGING_MODEL_NAME = 'google/vit-large-patch16-224-in21k'
# HUGGING_MODEL_NAME = 'google/vit-huge-patch14-224-in21k'

In [None]:
# Preprocessing config that ships with the backbone checkpoint.
# NOTE(review): `image_processor` is not referenced by the training code shown here;
# GSL_SI preprocesses frames with `video_transforms` instead.
image_processor = ViTImageProcessor.from_pretrained(HUGGING_MODEL_NAME)

In [None]:
# too much work to remove args from the code, so just create a dummy class and pass it to the functions
class Args:
    """Lightweight stand-in for a command-line ``args`` namespace.

    The dataset code expects an ``args`` object. This class supplies only the
    attributes that code reads: the working directory, the data root taken
    from ``config["dataset"]["input_data"]``, and a ``return_context`` flag
    that is always off.
    """

    def __init__(self, config):
        self.cwd = os.getcwd()
        dataset_section = config["dataset"]
        self.input_data = dataset_section["input_data"]
        self.return_context = False

In [None]:
# Folder (relative to the notebook's working directory) that holds the GSL SI split CSVs.
_GSL_SPLIT_DIR = "files/GSL_continuous"

# Feature directory name; not referenced elsewhere in this notebook.
feats_path = "gsl_cont_features/"

# Split names. GSL_SI receives one of these as `mode` and also uses it as a key
# into config["dataset"].
train_prefix, dev_prefix, validation_prefix, test_prefix = (
    "train",
    "dev",
    "validation",
    "test",
)

train_filepath = f"{_GSL_SPLIT_DIR}/gsl_split_SI_train.csv"
dev_filepath = f"{_GSL_SPLIT_DIR}/gsl_split_SI_dev.csv"
# Validation deliberately reuses the dev split file.
validation_filepath = dev_filepath
test_filepath = f"{_GSL_SPLIT_DIR}/gsl_split_SI_test.csv"

In [None]:
class GSL_SI(BaseDataset):
    """GSL continuous sign-language dataset, SI split files.

    Each item is one video (a folder of frame images under
    ``input_data/images_path``) together with its gloss sequence, converted to
    class indices with ``multi_label_to_index``.
    """

    def __init__(self, config, args, mode, classes):
        """
        Args:
            config: experiment config; only the ``"dataset"`` section is read.
            args: ``Args``-like object providing ``cwd`` and ``return_context``.
            mode: split name (``train_prefix``, ``validation_prefix`` or
                ``test_prefix``). It also selects ``config["dataset"][mode]``
                for ``seq_length`` and ``augmentation``.
            classes: ordered gloss names; list position is the class index.

        Raises:
            ValueError: if ``mode`` has no split CSV mapped to it.
        """
        super(GSL_SI, self).__init__(config, args, mode, classes)

        self.config = config["dataset"]
        self.mode = mode

        self.dim = tuple(self.config["dim"])
        self.classes = classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.num_classes = self.config["classes"]

        # Per-split settings come from config["dataset"][mode].
        self.seq_length = self.config[self.mode]["seq_length"]
        self.normalize = self.config["normalize"]
        self.padding = self.config["padding"]
        self.augmentation = self.config[self.mode]["augmentation"]

        self.return_context = args.return_context

        self.rgb_data_path = os.path.join(
            self.config["input_data"], self.config["images_path"]
        )
        self.depth_data_path = os.path.join(
            self.config["input_data"], self.config["depth_path"]
        )

        # Split CSV for each supported mode.
        split_files = {
            train_prefix: train_filepath,
            validation_prefix: validation_filepath,
            test_prefix: test_filepath,
        }
        if self.mode not in split_files:
            # Previously an unknown mode left list_IDs unset, and the dataset
            # failed later with an AttributeError.
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected one of {sorted(split_files)}"
            )
        self.list_IDs, self.list_glosses = read_gsl_continuous(
            os.path.join(args.cwd, split_files[self.mode])
        )

        print(f"{len(self.list_IDs)} {self.mode} instances")

        self.bbox = read_bounding_box(
            os.path.join(args.cwd, "files/GSL_continuous/bbox_for_gsl_continuous.txt")
        )

        self.context = gsl_context(self.list_IDs, self.list_glosses)

        # `self.get` selects the loader for the configured modality. Note that
        # __getitem__ below always loads RGB frames regardless of this choice.
        if self.config["modality"] == "full":
            self.data_path = os.path.join(
                self.config["input_data"], self.config["images_path"]
            )
            self.get = self.video_loader
        elif self.config["modality"] == "features":
            self.data_path = os.path.join(
                self.config["input_data"], self.config["features_path"]
            )
            self.get = self.feature_loader

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return ``(rgb_frames, label_indices)`` for item ``index``, with debug prints."""
        file_path = self.list_IDs[index]

        rgb_x = self.load_video_sequence(path=self.list_IDs[index], img_type="jpg")

        # depth_x = self.load_video_sequence(
        #     path=self.list_IDs[index].replace("color", "depth"), img_type="jpg"
        # )

        y = multi_label_to_index(
            classes=self.classes, target_labels=self.list_glosses[index]
        )
        print(f"  File path: {file_path}")
        print(f"In __getitem__ (index {index}): RGB shape: {rgb_x.shape}")
        # print(f"In __getitem__ (index {index}): Depth shape: {depth_x.shape}")
        print(f"In __getitem__ (index {index}): Label shape: {y.shape}")

        # return rgb_x, depth_x, y
        return rgb_x, y

    def feature_loader(self, index):
        """Load the precomputed ``.npy`` features for item ``index``.

        Returns ``(x, y)``, or ``(x, [y, c])`` when ``return_context`` is set.
        ``c`` holds the context labels, or ``tensor([0])`` when there is no context.
        """
        folder_path = os.path.join(self.data_path, self.list_IDs[index])

        y = multi_label_to_index(
            classes=self.classes, target_labels=self.list_glosses[index]
        )

        if self.context[index] is not None:
            c = multi_label_to_index(
                classes=self.classes, target_labels=self.context[index]
            )
        else:
            c = torch.tensor([0], dtype=torch.int)

        x = torch.FloatTensor(np.load(folder_path + ".npy")).squeeze(0)

        if self.return_context:
            return x, [y, c]

        return x, y

    def video_loader(self, index):
        """Load the RGB frames and label indices for item ``index``."""
        x = self.load_video_sequence(path=self.list_IDs[index], img_type="jpg")

        y = multi_label_to_index(
            classes=self.classes, target_labels=self.list_glosses[index]
        )

        return x, y

    def load_video_sequence(self, path, img_type="jpg", is_depth=False):
        """Load, sample, crop and transform the frames of one video.

        Args:
            path: video folder relative to the RGB (or depth) data root.
            img_type: file-name suffix used to select frame files.
            is_depth: read from ``depth_data_path`` instead of ``rgb_data_path``.

        Returns:
            Float tensor of stacked frames with time on dim 0. If ``padding``
            is enabled, it is zero-padded to ``seq_length``. In all cases it is
            zero-padded to at least 25 frames.

        Raises:
            FileNotFoundError: if no frame files match in the video folder.
        """
        data_path = self.depth_data_path if is_depth else self.rgb_data_path
        video_dir = os.path.join(data_path, path)

        images = sorted(glob.glob(video_dir + "/*" + img_type))

        if len(images) < 1:
            # Previously this printed `self.data_path` (not the folder actually
            # searched, and unset for some modalities) and then crashed in
            # torch.stack on the empty frame list.
            raise FileNotFoundError(f"No '*{img_type}' frames found in {video_dir}")

        # NOTE(review): looked up but not used below (bbox cropping is not implemented).
        bbox = self.bbox.get(path)

        if self.augmentation:
            # Training: temporal augmentation keeps a random 80-99% of frames
            # (only if more than 15 remain), then caps the clip at seq_length.
            temporal_augmentation = int(
                (np.random.randint(80, 100) / 100.0) * len(images)
            )
            if temporal_augmentation > 15:
                images = sorted(random.sample(images, k=temporal_augmentation))
            if len(images) > self.seq_length:
                # random frame sampling
                images = sorted(random.sample(images, k=self.seq_length))
        else:
            # Evaluation: uniform sampling down to seq_length.
            if len(images) > self.seq_length:
                images = sorted(sampling(images, self.seq_length))

        # Clip-level augmentation parameters, drawn once so that every frame of
        # the clip receives the same transform. Draw order is preserved.
        i = np.random.randint(0, 30)
        j = np.random.randint(0, 30)
        brightness = 1 + random.uniform(-0.2, +0.2)
        contrast = 1 + random.uniform(-0.2, +0.2)
        hue = random.uniform(0, 1) / 10.0
        # r_resize = (112, 112)
        r_resize = (224, 224)
        crop_or_bbox = random.uniform(0, 1) > 0.5
        # random.uniform(0, 1) never exceeds 1, so flipping is effectively
        # disabled; the draw is kept so the random stream is unchanged.
        to_flip = random.uniform(0, 1) > 1
        grayscale = random.uniform(0, 1) > 0.9
        t1 = VideoRandomResizedCrop(self.dim[0], scale=(0.9, 1.0), ratio=(0.8, 1.2))

        # Fixed horizontal crop keeping columns [120, 648 - 120).
        # NOTE(review): assumes frames are 648 px wide - confirm for this dataset.
        crop_size = 120
        img_sequence = []

        for img_path in images:
            # Context manager closes the underlying file once the frame is decoded.
            with Image.open(img_path) as img:
                frame_o = img.convert("RGB")

            frame = Image.fromarray(np.array(frame_o)[:, crop_size : 648 - crop_size])

            if self.augmentation:
                # training set DATA AUGMENTATION
                frame = frame.resize(r_resize)
                img_tensor = video_transforms(
                    img=frame,
                    i=i,
                    j=j,
                    bright=brightness,
                    cont=contrast,
                    h=hue,
                    dim=self.dim,
                    resized_crop=t1,
                    augmentation=True,
                    normalize=self.normalize,
                    crop=crop_or_bbox,
                    to_flip=to_flip,
                    grayscale=grayscale,
                )
            else:
                # TEST set: no augmentation; RGB and depth frames share the same resize.
                frame = frame.resize(self.dim)
                img_tensor = video_transforms(
                    img=frame,
                    i=i,
                    j=j,
                    bright=1,
                    cont=1,
                    h=0,
                    dim=self.dim,
                    augmentation=False,
                    normalize=self.normalize,
                )

            img_sequence.append(img_tensor)

        # Non-negative: sampling above caps len(images) at seq_length.
        pad_len = self.seq_length - len(images)

        X1 = torch.stack(img_sequence).float()

        if self.padding:
            X1 = pad_video(X1, padding_size=pad_len, padding_type="zeros")

        # Enforce a minimum clip length, measured on the tensor after any
        # padding above. The previous check used len(images), so with padding
        # enabled it padded a second time past seq_length.
        # NOTE(review): assumes pad_video pads along dim 0 (time), as its use above implies.
        min_frames = 25
        if X1.shape[0] < min_frames:
            X1 = pad_video(
                X1, padding_size=min_frames - X1.shape[0], padding_type="zeros"
            )

        return X1

In [None]:
class GSLDualViT(pl.LightningModule):
    """Per-frame ViT encoder followed by an LSTM and a linear classifier.

    Every frame of an RGB clip goes through a pretrained ``ViTModel``
    (``HUGGING_MODEL_NAME``). The per-frame [CLS] embeddings form a sequence
    for an LSTM, and the LSTM output at the last time step is mapped to
    ``num_classes`` logits. The depth ("dual") branch exists only as
    commented-out code.
    """

    def __init__(self, num_classes, config, class_names=None):
        """
        Args:
            num_classes: output width of the classifier.
            config: experiment config dict. Reads ``trainer.optimizer``,
                ``trainer.scheduler``, ``trainer.model.backbone.rnn`` and
                ``dataset.train.seq_length``.
            class_names: optional ordered gloss names, stored for reference.
                Previously this was read from the notebook-global
                ``class_names``, which tied the module to cell execution order.
        """
        super().__init__()
        self.config = config
        self.class_names = class_names
        # Matches the classifier width (previously len(global class_names),
        # which could disagree with `num_classes`).
        self.num_classes = num_classes
        # self.logger = logger
        self.learning_rate = config["trainer"]["optimizer"]["lr"]
        self.sequence_length = config["dataset"]["train"]["seq_length"]

        # RGB ViT, fully fine-tuned.
        self.rgb_vit = ViTModel.from_pretrained(HUGGING_MODEL_NAME)
        self.rgb_vit.train()
        for param in self.rgb_vit.parameters():
            param.requires_grad = True

        # # Depth ViT (initialize with same config but different input channels) # concat better
        # depth_config = ViTConfig.from_pretrained(HUGGING_MODEL_NAME)
        # depth_config.num_channels = 3
        # self.depth_vit = ViTModel(depth_config)

        # # Fusion layer
        # self.fusion = nn.Linear(
        # self.rgb_vit.config.hidden_size * 2, self.rgb_vit.config.hidden_size
        # )

        rnn_cfg = config["trainer"]["model"]["backbone"]["rnn"]

        # Temporal model over per-frame [CLS] embeddings.
        # TODO: consider replacing with a CNN feature extractor (e.g. EfficientNet/ResNet)
        # whose feature maps feed a transformer.
        self.lstm = nn.LSTM(
            self.rgb_vit.config.hidden_size,
            rnn_cfg["hidden_size"],
            num_layers=rnn_cfg["num_layers"],
            bidirectional=rnn_cfg["bidirectional"],
            batch_first=True,
            dropout=rnn_cfg["dropout"],
        )

        # A bidirectional LSTM concatenates both directions.
        lstm_output_size = rnn_cfg["hidden_size"]
        if rnn_cfg["bidirectional"]:
            lstm_output_size *= 2

        self.classifier = nn.Linear(lstm_output_size, num_classes)

    def forward(self, rgb_sequence, depth_sequence=None):
        """Classify a batch of RGB clips.

        Args:
            rgb_sequence: tensor unpacked as (batch, seq_len, channels, height, width).
            depth_sequence: ignored; kept for the disabled depth branch.

        Returns:
            Logits of shape (batch, num_classes), or None when ``seq_len`` is 0.
        """
        print(
            f"Forward pass input shapes: RGB {rgb_sequence.shape}, Depth {depth_sequence.shape if depth_sequence is not None else None}"
        )
        batch_size, seq_len, channels, height, width = rgb_sequence.shape
        features = []
        for i in range(seq_len):
            if (
                i == 0 or i == seq_len - 1 or i % 20 == 0
            ):  # Log first, last, and every 20th frame
                print(f"Processing frame {i}")
            try:
                # [CLS] token embedding of frame i: (batch, hidden_size)
                rgb_features = self.rgb_vit(rgb_sequence[:, i]).last_hidden_state[:, 0]
            except Exception as e:
                # Re-raise: skipping the frame would silently shorten the
                # sequence and hide real failures (shape errors, CUDA OOM).
                print(f"Error processing frame {i}: {str(e)}")
                raise
            print(f"RGB features shape: {rgb_features.shape}")
            features.append(rgb_features)
        print(f"Number of features: {len(features)}")
        if len(features) == 0:
            print("Features is empty!")
            return None
        sequence = torch.stack(features, dim=1)
        print(f"Sequence shape: {sequence.shape}")
        lstm_out, _ = self.lstm(sequence)
        print(f"LSTM output shape: {lstm_out.shape}")
        final_feature = lstm_out[:, -1, :]
        output = self.classifier(final_feature)
        print(f"Final output shape: {output.shape}")
        return output

    @staticmethod
    def _loss_and_accuracy(logits, labels):
        """Return ``(loss, accuracy)`` of ``logits`` against padded gloss ``labels``.

        ``labels`` comes from ``collate_fn`` as (batch, max_glosses),
        right-padded with -1. This takes one target per sample: the first gloss
        of each row. The previous approach flattened the labels and kept the
        first ``batch`` entries, which assigned sample 0's later glosses to
        samples 1..batch-1. Targets equal to the -1 pad value are ignored.
        """
        batch_size = logits.shape[0]
        if labels.dim() > 1:
            targets = labels[:batch_size, 0]
        else:
            targets = labels[:batch_size]
        targets = targets.long()
        print(f"Reshaped Labels shape: {targets.shape}")

        loss = nn.functional.cross_entropy(logits, targets, ignore_index=-1)  # maybe change to CTC loss
        valid = targets != -1
        preds = torch.argmax(logits, dim=1)
        acc = (preds[valid] == targets[valid]).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        # rgb_sequence, depth_sequence, labels = batch
        rgb_sequence, labels = batch
        print(f"RGB shape: {rgb_sequence.shape}")
        print(f"Labels shape: {labels.shape}")

        logits = self(rgb_sequence)
        print(f"Logits shape: {logits.shape if logits is not None else None}")

        if logits is None:
            print("Logits is None! Check the forward method.")
            # Returning None makes Lightning skip this optimisation step.
            return None

        loss, acc = self._loss_and_accuracy(logits, labels)

        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

        return loss

    def validation_step(self, batch, batch_idx):
        # rgb_sequence, depth_sequence, labels = batch
        rgb_sequence, labels = batch
        logits = self(rgb_sequence)

        if logits is None:
            # Same guard as training_step; previously this crashed on logits.shape.
            print("Logits is None! Check the forward method.")
            return None

        loss, acc = self._loss_and_accuracy(logits, labels)

        self.log(
            "val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "val_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

        return loss

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau driven by the logged ``val_loss``."""
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.config["trainer"]["optimizer"]["weight_decay"],
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=self.config["trainer"]["scheduler"]["scheduler_factor"],
            patience=self.config["trainer"]["scheduler"]["scheduler_patience"],
            min_lr=self.config["trainer"]["scheduler"]["scheduler_min_lr"],
            verbose=True,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }

In [None]:
def load_config(config_path):
    """Load a YAML experiment config file into a plain dict.

    ``yaml.safe_load`` only builds standard Python types and never constructs
    arbitrary objects. The file is decoded explicitly as UTF-8, so the result
    does not depend on the platform's default locale encoding.

    Args:
        config_path: path to the YAML file.

    Returns:
        The parsed config (typically a dict).
    """
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


# Experiment configuration; its "dataset" and "trainer" sections drive everything below.
config = load_config("configs/dummy.yaml")

In [None]:
def load_classes(file_path):
    """Read class (gloss) names from the first column of a CSV file.

    Args:
        file_path: path to a UTF-8 CSV file with one class per row.

    Returns:
        list[str]: class names in file order; list index is the class index.

    Notes:
        - Blank lines are skipped. ``csv.reader`` yields them as empty rows,
          which previously raised IndexError on ``row[0]``.
        - A leading UTF-8 BOM is removed by the ``utf-8-sig`` codec, so it no
          longer ends up in the first class name. BOM-less files decode as
          before.
    """
    # newline="" is the csv module's documented requirement for files it reads.
    with open(file_path, "r", encoding="utf-8-sig", newline="") as f:
        return [row[0] for row in csv.reader(f) if row]


# Ordered gloss vocabulary; GSL_SI maps each name to its list position as the class index.
class_names = load_classes(config["dataset"]["classes_filepath"])

In [None]:
def collate_fn(batch):
    """Merge ``(rgb_sequence, label)`` samples into one padded batch.

    RGB clips of different lengths are zero-padded along their first
    dimension with ``pad_sequence(batch_first=True)``. Label sequences are
    right-padded with -1 to the longest label sequence in the batch.

    Returns:
        ``(rgb_padded, labels)``, where ``labels`` has shape (batch, max_label_len).
    """
    # Depth branch disabled: samples are (rgb, label) pairs only.
    # return rgb_padded, depth_padded, labels
    rgb_sequences, label_sequences = zip(*batch)

    rgb_padded = pad_sequence(rgb_sequences, batch_first=True)

    longest = max(map(len, label_sequences))
    labels = torch.stack(
        [
            torch.tensor(seq.tolist() + [-1] * (longest - len(seq)))
            for seq in label_sequences
        ]
    )

    print(f"Collate function: RGB shape: {rgb_padded.shape}, Labels shape: {labels.shape}")

    return rgb_padded, labels

In [None]:
args = Args(config)

train_dataset = GSL_SI(config, args, mode="train", classes=class_names)
val_dataset = GSL_SI(config, args, mode="validation", classes=class_names)


def _make_loader(dataset, split_cfg):
    """Build a DataLoader from one ``config["dataset"][split]`` section.

    DataLoader raises ValueError for ``persistent_workers=True`` when
    ``num_workers == 0``. Persistent workers are therefore enabled only when
    worker processes exist, so ``num_workers: 0`` in the config (e.g. for
    debugging) no longer crashes.
    """
    num_workers = split_cfg["num_workers"]
    return DataLoader(
        dataset,
        batch_size=split_cfg["batch_size"],
        shuffle=split_cfg["shuffle"],
        num_workers=num_workers,
        collate_fn=collate_fn,
        persistent_workers=num_workers > 0,
        pin_memory=True,
    )


train_loader = _make_loader(train_dataset, config["dataset"]["train"])
val_loader = _make_loader(val_dataset, config["dataset"]["validation"])

In [None]:
# Sanity-check the training loader configuration.
print(f"Batch size: {train_loader.batch_size}")
print(f"Workers: {train_loader.num_workers}")

In [None]:
# NOTE(review): the classifier width comes from config["dataset"]["classes"]; confirm it
# equals len(class_names) so every label index has an output unit.
model = GSLDualViT(num_classes=config["dataset"]["classes"], config=config)

In [None]:
# Wrap the model with torch.compile; compilation happens lazily on the first forward calls.
model = torch.compile(model)

In [None]:
# trainer setup and initialization
trainer = pl.Trainer(
    max_epochs=config["trainer"]["epochs"],
    # accelerator='cpu' if config['trainer']['cuda'] else 'cpu', # uncomment this line to use CPU, and comment the next line
    accelerator="gpu" if config["trainer"]["cuda"] else "cpu",
    devices=1,
    # NOTE(review): bf16 mixed precision needs a GPU with bfloat16 support.
    precision="bf16-mixed",
    num_sanity_val_steps=0,
    callbacks=[
        # Checkpoint when val_loss is available (after validation), keeping the
        # 3 best by val_loss plus last.ckpt. The previous every_n_train_steps=1
        # wrote a checkpoint to the Drive mount after every training step,
        # including the whole first epoch, before val_loss had been computed.
        ModelCheckpoint(
            dirpath="/content/drive/Shareddrives/checkpoints/gsl_checkpoints",
            filename="gsl-dualvit-{epoch:02d}-{val_loss:.2f}",
            save_top_k=3,
            monitor="val_loss",
            mode="min",
            save_last=True,
        ),
        # Stops as soon as val_loss fails to improve for one validation round.
        EarlyStopping(monitor="val_loss", patience=1, mode="min"),
    ],
    enable_progress_bar=True,
    log_every_n_steps=1,
    logger=TensorBoardLogger("/content/drive/Shareddrives/checkpoints/logs", name=config["trainer"]["logger"]),
    enable_checkpointing=True,
)

## Pre-training setup

In [None]:
# Run Python's garbage collector to free unreferenced objects before training.
gc.collect()

In [None]:
# # get available memory info
# torch.cuda.mem_get_info()

# # memory footprint support libraries/code
# !ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi

# GPUs = GPU.getGPUs()
# # XXX: only one GPU on Colab and isn’t guaranteed
# gpu = GPUs[0]


# def printm():
#     process = psutil.Process(os.getpid())
#     print(
#         "Gen RAM Free: " + humanize.naturalsize(psutil.virtual_memory().available),
#         " |     Proc size: " + humanize.naturalsize(process.memory_info().rss),
#     )
#     print(
#         "GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total     {3:.0f}MB".format(
#             gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil * 100, gpu.memoryTotal
#         )
#     )


# # free up gpu memory before training
# def free_gpu_cache():

#     print("Initial GPU Usage")
#     gpu_usage()

#     torch.cuda.empty_cache()

#     cuda.select_device(0)

#     cuda.close()

#     cuda.select_device(0)

#     print("GPU Usage after emptying the cache")
#     gpu_usage()


# free_gpu_cache()
# printm()

In [None]:
# !nvidia-smi

In [None]:
# model.zero_grad()

In [None]:
# Open up permissions on the local TensorBoard log directory.
# NOTE(review): the second command makes the current user own the tree recursively; the
# third then resets only the top-level directory to root:root - confirm that is intended.
!sudo chmod -R 777 logs/train_CSLR/
!sudo chown -R $(whoami) logs/train_CSLR/
!sudo chown root:root logs/train_CSLR/

In [None]:
# Same permission reset for the checkpoint/log directory on the shared drive.
# NOTE(review): the Drive mount may not honour chmod/chown, and the final chown again
# undoes the recursive ownership change for the top-level directory - confirm.
!sudo chmod -R 777 /content/drive/Shareddrives/checkpoints
!sudo chown -R $(whoami) /content/drive/Shareddrives/checkpoints
!sudo chown root:root /content/drive/Shareddrives/checkpoints

In [None]:
# Let float32 matmuls use faster reduced-precision kernels; set before trainer.fit runs.
torch.set_float32_matmul_precision("medium")

## Training

In [None]:
# start training
trainer.fit(model, train_loader, val_loader)
# To resume from the most recent checkpoint, use this call instead:
# trainer.fit(model, train_loader, val_loader, ckpt_path="/content/drive/Shareddrives/checkpoints/gsl_checkpoints/last.ckpt")

## Model Summary

In [None]:
# Print the model structure and a per-parameter-tensor breakdown.
# The previous loop assumed every module owns exactly one 2-D weight followed by
# one bias, in module order. That does not hold for this model: container
# modules, multi-dimensional ViT embedding and patch-projection weights, 1-D
# LayerNorm parameters, and the LSTM's weight_ih/weight_hh/bias tensors. As a
# result it miscounted and could raise IndexError. numel() is exact for any shape.
net = model
modules = list(net.modules())

print(modules[0])
total_params = 0
trainable_params = 0
for name, param in net.named_parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count
    print(f"{name} ->\tShape: {tuple(param.shape)}\tParameters: {count}")
print("\nTotal Params: ", total_params)
print("Trainable Params: ", trainable_params)

In [None]:
# NOTE: torchinfo is not part of the pip installs at the top; run `!pip install torchinfo` first.
from torchinfo import summary

batch_size = 1
# GSLDualViT.forward unpacks inputs as (batch, seq_len, channels, height, width)
# and encodes each frame with a 224x224 ViT. The previous (1, 1, 28, 28) input
# (an MNIST-style shape) has only 4 dims and fails at that unpack.
seq_len = config["dataset"]["train"]["seq_length"]
summary(model, input_size=(batch_size, seq_len, 3, 224, 224))

## Evaluate on test

In [None]:
# TODO: evaluate the trained model on the test split (GSL_SI with mode=test_prefix).
pass

## Predictions

In [None]:
# TODO: generate and inspect predictions from the trained model.
pass

## Visualisation

In [None]:
# TODO: confusion matrix and F1 score.
pass

In [None]:
# TODO: plot accuracy against epoch (training, test, dev).
pass

In [None]:
# TODO: plot loss against epoch (training, test, dev).
pass

## Save to Hugging Face

In [None]:
# pl.Trainer has no `save_model` method (that is the Hugging Face transformers.Trainer API).
# Lightning saves with `save_checkpoint`, which takes a file path rather than a directory.
trainer.save_checkpoint("/content/drive/Shareddrives/checkpoints/gsl_checkpoints/final.ckpt")

In [None]:
# Interactive Hugging Face login, needed before creating or uploading to the Hub repo below.
notebook_login()

In [None]:
model_name = "gsl-dualvit"
api = HfApi()
repo_id = f"fawxyz/{model_name}"
# exist_ok=True turns "repo already exists" into a no-op, so only genuine
# failures (authentication, network, invalid name) raise. The previous bare
# `except` reported every failure as "already exists".
api.create_repo(repo_id, exist_ok=True)
print(f"Repo {repo_id} is ready")

In [None]:
# Upload the directory the checkpoints were written to. The previous
# folder_path=model_name ("gsl-dualvit") pointed at a folder that no cell shown
# here creates.
checkpoint_dir = "/content/drive/Shareddrives/checkpoints/gsl_checkpoints"
api.upload_folder(
    folder_path=checkpoint_dir, path_in_repo=".", repo_id=repo_id, repo_type="model"
)

## Testing model from Hugging Face

In [None]:
# ViTFeatureExtractor stays imported for cells that still use it; new code uses
# ViTImageProcessor, its non-deprecated replacement.
from transformers import ViTFeatureExtractor, ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import numpy as np
import requests
from IPython.display import display

# NOTE(review): the upload cell pushes Lightning .ckpt files, but from_pretrained needs a
# transformers config and weights in the repo - confirm the repo contents before running.

url = "https://media.springernature.com/lw685/springer-static/image/chp%3A10.1007%2F978-3-031-12638-3_33/MediaObjects/527350_1_En_33_Fig11_HTML.png"
# The timeout makes a stalled download fail instead of hanging the cell.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# The figure is a PNG (possibly RGBA or palette mode); ViT expects 3-channel RGB.
image = Image.open(response.raw).convert("RGB")

feature_extractor = ViTImageProcessor.from_pretrained('fawxyz/gsl-dualvit')
model = ViTForImageClassification.from_pretrained('fawxyz/gsl-dualvit')

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
# The index is into the repo config's id2label mapping (its own label set, not ImageNet).
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

In [None]:
url = "https://media.springernature.com/lw685/springer-static/image/chp%3A10.1007%2F978-3-031-12638-3_33/MediaObjects/527350_1_En_33_Fig11_HTML.png"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

print("Original Image:")
display(image)

image_size = (224, 224)
image_resized = image.resize(image_size)

print("Resized Image:")
display(image_resized)

feature_extractor = ViTImageProcessor.from_pretrained('fawxyz/gsl-dualvit')
model = ViTForImageClassification.from_pretrained('fawxyz/gsl-dualvit')

# Preprocess with the checkpoint's own processor, so rescaling and mean/std
# normalisation match what the model expects. The previous manual path only
# divided by 255 and skipped the normalisation step.
inputs = feature_extractor(images=image_resized.convert('RGB'), return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
# Softmax computed once for the single image, not once per top-k row.
probabilities = torch.softmax(logits, dim=1)[0]

predicted_class_idx = logits.argmax(-1).item()
predicted_class_label = model.config.id2label[predicted_class_idx]

print(f"\nPredicted class index: {predicted_class_idx}")
print(f"Predicted class label: {predicted_class_label}")

# topk(5) would fail if the model had fewer than 5 labels.
k = min(5, probabilities.shape[-1])
top_k = probabilities.topk(k)
for rank, (probability, class_idx) in enumerate(
    zip(top_k.values.tolist(), top_k.indices.tolist())
):
    class_label = model.config.id2label[class_idx]
    print(f"Top {rank+1}: {class_label} (Index: {class_idx}, Probability: {probability:.4f})")