In [1]:
!module load cuda/12.2.0
!nvidia-smi

!nvcc --version
!conda info --envs

import sys
print("Environment:", sys.executable)

import torch

print(torch.__version__)
print(torch.version.cuda)
print("Is GPU available:", torch.cuda.is_available())

Loading module 'cuda/12.2.0'[m
Loading module 'gcc/9.2.0'[m
[m
Loading [1mcuda/12.2.0[22m[m
  [94mLoading requirement[0m: gcc/9.2.0[m
[K[?1l>Sat Dec 14 20:59:32 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          Off | 00000000:E2:00.0 Off |                    0 |
| N/A   32C    P0              63W / 300W |  31512MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+----

In [2]:
# General
import os
import numpy as np
import pandas as pd
import re
import zipfile
import matplotlib.pyplot as plt

# PyTorch
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn as nn
from torch.nn import functional as F

import shutil  # used below to copy the credentials and to locate the Kaggle CLI

# Kaggle API configuration
kaggle_json_path = "kaggle.json"
if not os.path.exists(kaggle_json_path):
    raise FileNotFoundError("kaggle.json not found in the current directory.")

# Create the ~/.kaggle directory if it doesn't exist
kaggle_dir = os.path.expanduser("~/.kaggle")
os.makedirs(kaggle_dir, exist_ok=True)

# Copy kaggle.json into ~/.kaggle and make it owner read/write only (mode 600).
# shutil.copy / os.chmod raise on failure, whereas the exit codes of
# os.system("cp ...") / os.system("chmod ...") would be silently ignored.
kaggle_json_dest = os.path.join(kaggle_dir, "kaggle.json")
shutil.copy(kaggle_json_path, kaggle_json_dest)
os.chmod(kaggle_json_dest, 0o600)

# Specify the dataset name
dataset_name = "williamscott701/memotion-dataset-7k"

# Locate the Kaggle CLI: prefer the `final_ml` conda env's binary when it exists,
# otherwise fall back to whatever `kaggle` is already on PATH.
default_kaggle_path = os.path.expanduser("~/.conda/envs/final_ml/bin/kaggle")
if os.path.exists(default_kaggle_path):
    kaggle_path = default_kaggle_path
else:
    kaggle_path = shutil.which("kaggle") or default_kaggle_path

# Add the CLI's directory to PATH only once, so re-running this cell does not keep growing PATH
kaggle_bin_dir = os.path.dirname(kaggle_path)
current_path = os.environ.get("PATH", "")
if kaggle_bin_dir not in current_path.split(os.pathsep):
    os.environ["PATH"] = current_path + os.pathsep + kaggle_bin_dir

# Verify Kaggle CLI installation
kaggle_version = os.popen(f"{kaggle_path} --version").read()
if "Kaggle API" not in kaggle_version:
    raise EnvironmentError("Kaggle CLI is not properly installed or accessible.")

# Download the dataset (the CLI skips the download when a more recent local copy exists)
download_status = os.system(f"{kaggle_path} datasets download -d {dataset_name}")
if download_status != 0:
    raise RuntimeError("Failed to download dataset using Kaggle CLI.")

# Unzip the dataset
zip_file = "memotion-dataset-7k.zip"
dataset_folder = "memotion-dataset"
if not os.path.exists(zip_file):
    raise FileNotFoundError(f"{zip_file} not found after Kaggle download.")
with zipfile.ZipFile(zip_file, "r") as zip_ref:
    zip_ref.extractall(dataset_folder)
print(f"Dataset extracted to '{dataset_folder}/'")

# Verify dataset contents
if not os.path.exists(dataset_folder):
    raise FileNotFoundError("Dataset folder not found after extraction.")
print(f"Dataset successfully prepared in '{dataset_folder}'.")

print("Setup complete!")


Dataset URL: https://www.kaggle.com/datasets/williamscott701/memotion-dataset-7k
License(s): other
memotion-dataset-7k.zip: Skipping, found more recently modified local copy (use --force to force download)
Dataset extracted to 'memotion-dataset/'
Dataset successfully prepared in 'memotion-dataset'.
Setup complete!


In [3]:
unprocessed_data = pd.read_csv('memotion-dataset/memotion_dataset_7k/labels.csv')
# Verify that the file is loaded properly
print(unprocessed_data.head())

# Label mappings for the target columns that have been experimented with. Both collapse
# their two strongest levels into class 2, giving three classes (0, 1, 2).
LABEL_MAPS = {
    "humour": {"funny": 1, "very_funny": 2, "not_funny": 0, "hilarious": 2},
    "offensive": {"not_offensive": 0, "slight": 1, "very_offensive": 2, "hateful_offensive": 2},
}
target_column = "humour"  # set to "offensive" to classify offensiveness instead

# Select the label column and map its string labels to integer classes
labels_map = LABEL_MAPS[target_column]
processed_data = unprocessed_data[target_column].map(labels_map)

# Series.map turns any label missing from labels_map into NaN; fail loudly here instead
# of letting NaN reach the integer training targets.
unmapped_labels = unprocessed_data.loc[processed_data.isna(), target_column].unique()
if len(unmapped_labels) > 0:
    raise ValueError(f"Unmapped {target_column!r} labels: {list(unmapped_labels)}")

# Verify by printing the first few values and by checking the number of rows
print(processed_data.head())
print(processed_data.shape) # should be equal to number of images

# Save the processed labels to disk
processed_data.to_csv('data.csv')

   Unnamed: 0    image_name  \
0           0   image_1.jpg   
1           1  image_2.jpeg   
2           2   image_3.JPG   
3           3   image_4.png   
4           4   image_5.png   

                                            text_ocr  \
0  LOOK THERE MY FRIEND LIGHTYEAR NOW ALL SOHALIK...   
1  The best of #10 YearChallenge! Completed in le...   
2  Sam Thorne @Strippin ( Follow Follow Saw every...   
3              10 Year Challenge - Sweet Dee Edition   
4  10 YEAR CHALLENGE WITH NO FILTER 47 Hilarious ...   

                                      text_corrected      humour  \
0  LOOK THERE MY FRIEND LIGHTYEAR NOW ALL SOHALIK...   hilarious   
1  The best of #10 YearChallenge! Completed in le...   not_funny   
2  Sam Thorne @Strippin ( Follow Follow Saw every...  very_funny   
3              10 Year Challenge - Sweet Dee Edition  very_funny   
4  10 YEAR CHALLENGE WITH NO FILTER 47 Hilarious ...   hilarious   

           sarcasm       offensive      motivational overall_sentim

In [4]:
# Pull the OCR-corrected caption of every CSV row, in row order
texts = list(unprocessed_data['text_corrected'])
print(len(texts))

import re

# normalize input text
def normalize(text):
    """Lower-case `text` and split it into tokens.

    A token is either a maximal run of word characters or a single character that is
    neither a word character nor whitespace (punctuation, symbols). Whitespace is dropped.
    Example: "Hi, there!" -> ['hi', ',', 'there', '!'].
    """
    lowered = text.lower()
    token_regex = re.compile(r"\w+|[^\w\s]")
    return [match.group(0) for match in token_regex.finditer(lowered)]

# Extract the targets
labels = processed_data.tolist()  # Convert to list for easier indexing

# Re-read the raw texts from the CSV so that re-running this cell filters the original
# rows, not an already-filtered list whose positions no longer match the CSV.
raw_texts = unprocessed_data['text_corrected'].tolist()
assert len(raw_texts) == len(labels), "texts and labels must come from the same rows"

# Row 5118 (image_5119.png) is excluded explicitly in addition to empty/non-string texts.
# NOTE(review): the reason is not recorded here - presumably the sample is unusable; confirm.
EXTRA_DISREGARDED_INDICES = {5118}

# Find indices of empty or invalid texts
disregarded_indices = [
    i for i, text in enumerate(raw_texts)
    if not isinstance(text, str) or not text.strip() or i in EXTRA_DISREGARDED_INDICES
]
disregarded_set = set(disregarded_indices)  # O(1) membership for the filters below

# Remove texts and corresponding targets
texts = [text for i, text in enumerate(raw_texts) if i not in disregarded_set]
labels = [label for i, label in enumerate(labels) if i not in disregarded_set]

print(f"Number of valid texts: {len(texts)}")
print(f"Number of valid targets: {len(labels)}")
print(f"Indices of disregarded images: {disregarded_indices}")

# Split texts into words (every remaining text is a non-empty string)
words = [word for text in texts for word in normalize(text)] ## From ChatGPT

# Create vocabulary of unique words
vocabulary = sorted(set(words))
word_to_index = {word: idx for idx, word in enumerate(vocabulary)}  ### From ChatGPT

# Create 1-hot vectors for each word in the vocabulary
vocab_size = len(vocabulary)
print(vocab_size)
# float32 instead of numpy's default float64 halves the size of this
# vocab_size x vocab_size identity matrix; batches are cast to torch.float downstream anyway.
one_hot_vectors = np.eye(vocab_size, dtype=np.float32)

# Save to a matrix with word-to-one-hot mapping (each value is a row view of one_hot_vectors)
one_hot_matrix = {word: one_hot_vectors[idx] for word, idx in word_to_index.items()}

# Verify
word = vocabulary[200]
one_hot_vector = one_hot_matrix[word]
print(f"One-hot vector for '{word}': {one_hot_vector}", "has shape", one_hot_vector.shape)

# Imports for Image Processing
from torchvision import transforms
from PIL import Image

# Define image transformations
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor()           # Convert PIL Image to Tensor
])

images_dir = 'memotion-dataset/memotion_dataset_7k/images'
image_files = sorted(os.listdir(images_dir))

# Index image files by the number in their name: CSV row i corresponds to image_<i+1>.<ext>.
pattern = r"_(\d+)\."
files_by_number = {}
for file_name in image_files:
    match = re.search(pattern, file_name)
    if match is None:
        continue  # not an image_<n>.<ext> file (e.g. a stray hidden file)
    number = int(match.group(1))
    if number in files_by_number:
        raise ValueError(f"Two image files share number {number}: {files_by_number[number]}, {file_name}")
    files_by_number[number] = file_name

# Build image paths row by row so they stay aligned with `texts` and `labels`
image_paths = []
for i in range(len(raw_texts)):
    file_name = files_by_number.get(i + 1)
    if i in disregarded_set:
        if file_name is not None:
            print(os.path.join(images_dir, file_name))  # report the skipped image
        continue
    if file_name is None:
        raise FileNotFoundError(f"No image file for CSV row {i} (expected image_{i + 1}.*) in {images_dir}")
    image_paths.append(os.path.join(images_dir, file_name))

print(len(image_paths), len(texts))
assert len(image_paths) == len(texts) == len(labels), "images, texts and labels are misaligned"


6992
Number of valid texts: 6986
Number of valid targets: 6986
Indices of disregarded images: [119, 4799, 5118, 6781, 6784, 6786]
13035
One-hot vector for '207r': [0. 0. 0. ... 0. 0. 0.] has shape (13035,)
memotion-dataset/memotion_dataset_7k/images/image_120.jpg
memotion-dataset/memotion_dataset_7k/images/image_4800.jpg
memotion-dataset/memotion_dataset_7k/images/image_5119.png
memotion-dataset/memotion_dataset_7k/images/image_6782.jpg
memotion-dataset/memotion_dataset_7k/images/image_6785.jpg
memotion-dataset/memotion_dataset_7k/images/image_6787.jpg
6986 6986


In [5]:
# Some of the codes here are from Assignment 5

# Positional (unshuffled) 80% / 10% / 10% split of the aligned lists.
n = int(0.8 * len(labels))  # end of the training slice
m = int(0.9 * len(labels))  # end of the validation slice


def _split_three_ways(items):
    """Slice `items` into (train, val, test) at the global positions n and m."""
    return items[:n], items[n:m], items[m:]


train_labels, val_labels, test_labels = _split_three_ways(labels)
train_texts, val_texts, test_texts = _split_three_ways(texts)
# Split image paths corresponding to text splits
train_image_paths, val_image_paths, test_image_paths = _split_three_ways(image_paths)

print(f"Train images: {len(train_image_paths)}, Validation images: {len(val_image_paths)}, Test images: {len(test_image_paths)}")

print(image_paths[99], texts[99], labels[99])


Train images: 5588, Validation images: 699, Test images: 699
memotion-dataset/memotion_dataset_7k/images/image_100.jpg Drink water you may not meme-generator.com 2


In [6]:
# transformer hyperparams (some taken from Assignment 5)
n_classes = 3 # number of target classes (labels are mapped to 0, 1, 2)
batch_size = 128 # (previously 64) examples sampled per batch; also caps the number of text windows per batch
block_size = 128 # context length in tokens (256 was also tried)
window_stride = 5 # (previously 2) tokens to advance between sliding windows when a text is longer than block_size
max_iters = 5000 # number of training iterations (mini-batch steps, not epochs)
eval_interval = 500 # evaluate losses / accuracies every `eval_interval` iterations
learning_rate = 3e-4 # learning rate for optimization
eval_iters = 200 # number of random batches averaged per evaluation
hotvector_dim = vocab_size # one-hot vector size (= vocabulary size)
n_embd = 384 # (768 also tried) embedding dimension
n_head = 2 # number of attention heads
n_layer = 2 # number of transformer blocks
dropout = 0.2 # dropout rate
# Block splits the embedding evenly across heads (head_size = n_embd // n_head).
assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)


cuda


In [7]:
class CnnModel(nn.Module):
    """Image encoder: four conv/ReLU/max-pool stages followed by a two-layer MLP.

    Expects 3 x 224 x 224 images and returns one n_embd-dimensional feature vector per
    image, so the result can be added to the text encoder's output.

    Note: `feature_dim` is stored on the module but does not shape the network; the
    output width is the global `n_embd`.
    """

    def __init__(self, feature_dim=256):
        super().__init__()
        self.feature_dim = feature_dim

        # Each stage: 3x3 conv (same padding) -> ReLU -> 2x2 max-pool, halving H and W.
        # Spatial size: 224 -> 112 -> 56 -> 28 -> 14; channels: 3 -> 32 -> 64 -> 128 -> 256.
        channel_plan = (3, 32, 64, 128, 256)
        stages = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
        self.conv_layers = nn.Sequential(*stages)

        self.fc = nn.Sequential(
            nn.Flatten(),                    # (B, 256, 14, 14) -> (B, 256 * 14 * 14)
            nn.Linear(256 * 14 * 14, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, n_embd),         # final feature vector
            nn.ReLU(),
        )

    def forward(self, x):
        """Map a (B, 3, 224, 224) image batch to (B, n_embd) features."""
        return self.fc(self.conv_layers(x))


In [8]:
torch.manual_seed(1337)


def get_batch_with_images(split):
    """Sample a random mini-batch of (text windows, images, labels) from one split.

    `batch_size` example indices are drawn with replacement. Each text is normalized,
    restricted to in-vocabulary tokens and one-hot encoded. Texts longer than
    `block_size` tokens are cut into overlapping windows (stride `window_stride`), and
    each window becomes its own row of `x_text`. Rows are zero-padded to `block_size`.
    Sampling stops once `batch_size` text rows have been collected, so a batch can hold
    fewer than `batch_size` examples when long texts are drawn.

    Args:
        split: 'train', 'val' or 'test'.

    Returns:
        x_text: float tensor (num_windows, block_size, vocab_size) on `device`.
        x_images: float tensor (num_examples, C, H, W) on `device`, as produced by `image_transforms`.
        y: long tensor (num_examples,) on `device`.
        masks: float tensor (num_windows, block_size) on `device`; 1 for real tokens, 0 for padding.
        window_map: list of num_windows ints; window_map[r] is the (0-based, per-batch)
            example that row r of `x_text` belongs to.

    Raises:
        ValueError: if `split` is not one of the three names above.
    """
    splits = {
        'train': (train_texts, train_labels, train_image_paths),
        'val': (val_texts, val_labels, val_image_paths),
        'test': (test_texts, test_labels, test_image_paths),
    }
    if split not in splits:
        raise ValueError(f"unknown split {split!r}; expected 'train', 'val' or 'test'")
    data, targets, img_paths = splits[split]

    # Generate random indices for batch (sampled with replacement)
    ix = torch.randint(0, len(data), (batch_size,))

    x_text, x_images, y, masks, window_map = [], [], [], [], []
    current_window_id = 0

    for i in ix.tolist():
        if len(x_text) >= batch_size:
            break  # the batch is already full of text windows

        # Process text: normalized in-vocabulary tokens -> (n_tokens, vocab_size) one-hot rows
        tokens = [word for word in normalize(data[i]) if word in one_hot_matrix]
        if not tokens:
            continue  # nothing to encode (torch.stack would fail on an empty list)
        one_hot_input = torch.stack([torch.tensor(one_hot_matrix[word]) for word in tokens])
        n_tokens = one_hot_input.shape[0]

        # Texts longer than block_size become overlapping windows over the SAME normalized
        # tokens; an extra end-aligned window is added so no trailing token is dropped.
        if n_tokens > block_size:
            starts = list(range(0, n_tokens - block_size + 1, window_stride))
            if starts[-1] + block_size < n_tokens:
                starts.append(n_tokens - block_size)
            windows = [one_hot_input[s:s + block_size] for s in starts]
        else:
            windows = [one_hot_input]

        for window in windows:
            if len(x_text) >= batch_size:
                break
            n_real = window.shape[0]
            # Build the mask from the UNPADDED length: 1 for real tokens, 0 for padding.
            mask = torch.zeros(block_size)
            mask[:n_real] = 1.0
            if n_real < block_size:
                padding = torch.zeros(block_size - n_real, vocab_size, dtype=window.dtype)
                window = torch.cat((window, padding), dim=0)
            x_text.append(window)
            masks.append(mask)
            window_map.append(current_window_id)

        # One label and one image per example, however many windows its text produced.
        y.append(targets[i])
        current_window_id += 1
        with Image.open(img_paths[i]) as img:  # close the file handle once decoded
            x_images.append(image_transforms(img.convert('RGB')))

    x_text = torch.stack(x_text).to(device, dtype=torch.float)      # (num_windows, block_size, vocab_size)
    x_images = torch.stack(x_images).to(device, dtype=torch.float)  # (num_examples, C, H, W)
    y = torch.tensor(y, dtype=torch.long).to(device)               # (num_examples,)
    masks = torch.stack(masks).to(device, dtype=torch.float)        # (num_windows, block_size)

    return x_text, x_images, y, masks, window_map


In [9]:
def sliding_window_split(input_text, block_size, window_stride):
    """Split a text into overlapping windows of at most `block_size` tokens.

    Args:
        input_text: a string (split on whitespace) or an already-tokenized sequence.
        block_size: maximum number of tokens per window.
        window_stride: number of tokens to advance between consecutive windows.

    Returns:
        A list of token lists. Empty input gives []; input of at most `block_size`
        tokens gives a single window; otherwise windows start every `window_stride`
        tokens, plus one final window covering the last `block_size` tokens when the
        regular windows stop short of the end.
    """
    tokens = input_text.split() if isinstance(input_text, str) else list(input_text)
    if not tokens:
        return []
    if len(tokens) <= block_size:
        return [tokens]

    starts = range(0, len(tokens) - block_size + 1, window_stride)
    windows = [tokens[start:start + block_size] for start in starts]
    # Tail handling (originally from ChatGPT): add an end-aligned window only when the last
    # regular window ends before the final token. This neither drops trailing tokens nor
    # duplicates a window that already reaches the end.
    if starts[-1] + block_size < len(tokens):
        windows.append(tokens[-block_size:])
    return windows


class Head(nn.Module):
    """ Single head of masked, scaled dot-product self-attention """

    def __init__(self, head_size):
        super().__init__()
        # keys, queries and values: bias-free linear maps from n_embd to head_size
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout) # dropout on the attention weights

    def forward(self, x, mask):
        """Attend over the sequence, ignoring padded key positions.

        Args:
            x: tensor (B, T, n_embd).
            mask: tensor (B, T); positions equal to 0 are padding and receive no attention.

        Returns:
            Tensor (B, T, head_size).
        """
        B, T, C = x.shape
        k = self.key(x)   # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        # Scale by 1/sqrt(head_size) so the logits' variance does not grow with head_size
        # and saturate the softmax.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5 # (B, T, T)
        # Mask padded keys (mask broadcast over the query axis) before the softmax
        wei = wei.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # weighted aggregation of the values
        v = self.value(x) # (B, T, head_size)
        out = wei @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel, concatenated and projected back to n_embd """

    def __init__(self, num_heads, head_size):
        super().__init__()
        # Attention Mechanism (from Assignment 5)
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide (equal to n_embd
        # when n_embd is divisible by the number of heads); project them back to n_embd.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout) # dropout layer

    def forward(self, x, masks):
        """Run every head on (x, masks), concatenate along features, project, then dropout."""
        out = torch.cat([h(x, masks) for h in self.heads], dim=-1) # (B, T, num_heads * head_size)
        return self.dropout(self.proj(out))

class FeedFoward(nn.Module):
    """Position-wise MLP: widen to 4 * n_embd, ReLU, project back to n_embd, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP independently at every position; the shape is preserved."""
        transformed = self.net(x)
        return transformed

class Block(nn.Module):
    """ Pre-norm transformer block: masked self-attention then feed-forward, each with a residual connection """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Split the embedding evenly across heads: n_embd = head_size * n_head
        self.sa = MultiHeadAttention(num_heads=n_head, head_size=n_embd // n_head)
        self.ffwd = FeedFoward(n_embd=n_embd)
        self.ln1 = nn.LayerNorm(n_embd) # normalizes the attention input
        self.ln2 = nn.LayerNorm(n_embd) # normalizes the feed-forward input

    def forward(self, x, masks):
        """Return x after the attention and feed-forward residual updates (shape unchanged)."""
        after_attention = x + self.sa(self.ln1(x), masks)
        return after_attention + self.ffwd(self.ln2(after_attention))


In [10]:
class Transformer(nn.Module):
    """Text encoder: one-hot token windows -> transformer blocks -> one vector per text.

    Input rows that share an id in `window_map` are windows of the same text; the output
    has one row per distinct id, in ascending id order.
    """

    def __init__(self):
        super().__init__()
        # Word embeddings as a bias-free projection of the one-hot vectors, plus learned
        # position embeddings (one per position up to block_size)
        self.embed_proj = nn.Linear(hotvector_dim, n_embd, bias=False)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm

    def forward(self, idx, masks, window_map, targets=None):
        """Encode a batch of token windows and pool them per text.

        Args:
            idx: float tensor (B, T, hotvector_dim) of one-hot tokens, with T <= block_size.
            masks: tensor (B, T); 1 for real tokens, 0 for padding.
            window_map: sequence of B ints mapping each window (row) to its text.
            targets: unused; kept for interface compatibility.

        Returns:
            Tensor (num_texts, n_embd): mean over the real tokens of each window, then
            mean over the windows of each text.
        """
        _, T, _ = idx.shape

        positions = torch.arange(T, device=idx.device)
        x_text = self.embed_proj(idx) + self.position_embedding(positions) # (B, T, n_embd)
        for block in self.blocks:
            x_text = block(x_text, masks) # (B, T, n_embd)
        x_text = self.ln_f(x_text) # (B, T, n_embd)

        # Average over real tokens only, so padding positions do not dilute short texts.
        weights = masks.to(x_text.dtype).unsqueeze(-1)            # (B, T, 1)
        token_counts = weights.sum(dim=1).clamp(min=1.0)          # (B, 1)
        per_window = (x_text * weights).sum(dim=1) / token_counts # (B, n_embd)

        # Average the windows that belong to the same text.
        window_ids = torch.as_tensor(window_map, device=per_window.device)
        pooled = [per_window[window_ids == wid].mean(dim=0) for wid in torch.unique(window_ids)]
        return torch.stack(pooled) # (num_texts, n_embd)


In [11]:
class ClassificationModel(nn.Module):
    """Fuses image and text features with learnable scalar weights and classifies the sum.

    forward returns (logits, loss); loss is the cross-entropy against `targets` when they
    are given, otherwise None.
    """

    def __init__(self):
        super().__init__()
        self.image_encoder = CnnModel(feature_dim=256).to(device)
        self.text_encoder = Transformer().to(device)
        # Learnable scalar weight for each modality, both starting at 0.5
        self.alpha_txt = nn.Parameter(torch.tensor(0.5))
        self.alpha_img = nn.Parameter(torch.tensor(0.5))
        head_layers = [
            nn.Linear(n_embd, 512),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes),
        ]
        self.classifier_head = nn.Sequential(*head_layers)

    def forward(self, idx, images, masks, window_map, targets=None):
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(idx, masks, window_map)

        fused = self.alpha_txt * text_features + self.alpha_img * image_features
        logits = self.classifier_head(fused)

        loss = F.cross_entropy(logits, targets) if targets is not None else None
        return logits, loss



In [12]:
# Accuracy helpers. Each one scores the module-level `model` on `eval_iters` batches drawn
# at random (with replacement) from one split, so the numbers are sampled estimates
# rather than a single pass over the split.


@torch.no_grad()  # inference only: do not build autograd graphs
def _accuracy_on_split(split):
    """Return the fraction of correct predictions over `eval_iters` random batches of `split`.

    Leaves the model in eval mode; callers must call model.train() again before training.

    Raises:
        ValueError: if no samples were scored (e.g. eval_iters <= 0).
    """
    model.eval()
    total_correct = 0
    total_samples = 0

    for _ in range(eval_iters):
        xb, xb_images, yb, masks, window_map = get_batch_with_images(split)
        logits, _ = model(xb, xb_images, masks, window_map)
        predictions = torch.argmax(logits, dim=1)  # (num_examples,)
        total_correct += (predictions == yb).sum().item()
        total_samples += yb.size(0)

    if total_samples == 0:
        raise ValueError("No samples were evaluated; check that eval_iters > 0.")
    return total_correct / total_samples


# Function to evaluate the model
def evaluate_model():
    """Print and return accuracy on the test split."""
    accuracy = _accuracy_on_split('test')
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy


def train_accuracy():
    """Print and return accuracy on the train split."""
    accuracy = _accuracy_on_split('train')
    print(f"Train Accuracy: {accuracy * 100:.2f}%")
    return accuracy


def validation_accuracy():
    """Return accuracy on the validation split (not printed)."""
    return _accuracy_on_split('val')


def test_accuracy():
    """Return accuracy on the test split (not printed)."""
    return _accuracy_on_split('test')


In [None]:
# Initialize the model
model = ClassificationModel()
model = model.to(device)
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

best_model_path = 'best_model.pth'
best_val_accuracy = 0.0
patience = 3  # evaluations in a row without a val-accuracy improvement before stopping
cnt = 0       # evaluations since the last improvement
for step in range(max_iters):

    # every once in a while evaluate the loss on the train, val and test sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = {}
        model.eval()
        with torch.no_grad():  # evaluation only: skip building autograd graphs
            for split in ['train', 'val', 'test']:
                total_loss = 0.0
                for _ in range(eval_iters):
                    xb, xb_images, yb, masks, window_map = get_batch_with_images(split)
                    _, loss = model(xb, xb_images, masks, window_map, yb)
                    total_loss += loss.item()
                losses[split] = total_loss / eval_iters

        val_accuracy = validation_accuracy()
        test_acc = test_accuracy()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, test loss {losses['test']:.4f}, val accuracy {val_accuracy:.4f}, test accuracy: {test_acc}")
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            cnt = 0
            torch.save(model.state_dict(), best_model_path) # keep the best weights so far
        else:
            cnt += 1

        if cnt >= patience:
            print(f"Early stop at {step} iterations")
            break
        model.train()

    # sample a batch of data
    xb, xb_images, yb, masks, window_map = get_batch_with_images('train')

    # Forward pass, backward pass, parameter update
    logits, loss = model(xb, xb_images, masks, window_map, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Evaluate the checkpoint selected on validation accuracy, not the last-iterate weights.
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=device))

# Evaluate the model
evaluate_model()
train_accuracy()

61.028421 M parameters
