# Official model

In [1]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from aiohttp._websocket import mask

# Check if GPU is available
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"GPU count: {torch.cuda.device_count()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'
print(f"Using device: {device}")

# import os


PyTorch version: 2.6.0+cpu
CUDA available: False
CUDA version: None
GPU count: 0
Using device: cpu


# Load dataset

In [2]:
from datasets import load_dataset

ds = load_dataset("uoft-cs/cifar10", split='train')
print(ds)

Dataset({
    features: ['img', 'label'],
    num_rows: 50000
})


# Prepare data

In [3]:

from torchvision import transforms

labels_map = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse",
              8: "ship", 9: "truck"}

transform = transforms.Compose([
    transforms.PILToTensor()
])
images = []
labels = []
for i, data in enumerate(ds):
    img_tensor = transform(data['img'])
    images.append(img_tensor)
    labels.append(data['label'])

    if i == 500:
        break
# (C, H, W) --> channels, height, width. Channels are: Red, Green, Blue (each pixel can have <0-255) values in each channel. They together (mixed) provide what we see as image)
PATCH_SIZE = 4
print('Patch Size:', PATCH_SIZE)
img: torch.Tensor = images[0]
changed = img.unfold(dimension=1, size=PATCH_SIZE,
                     step=PATCH_SIZE)  # unfold on dimension 1 so Height and we get: (Channels, New Height, Width, Patch Height)
changed = changed.unfold(dimension=2, size=PATCH_SIZE,
                         step=PATCH_SIZE)  # unfold on dimension 2 so Width and we get: (Channels, New Height, New Width, Patch Height, Patch Width)
# Finally we get 3, 8, 8, 4, 4 --> channels, number of patches along height, number of patches along width, patch height, patch width

# ch - channel, nph - number of patches along height, npw - number of patches along width, ph - patch height, pw - patch width
for i in range(len(images)):
    new_image: torch.Tensor = images[i].unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE,
                                                                                 PATCH_SIZE)  # ch, nph, npw, ph, pw
    new_image = new_image.swapdims(0, 1).swapdims(1, 2)  # nph, npw, ch, ph, pw
    new_image = new_image.reshape(new_image.size(0) * new_image.size(1), new_image.size(2) * new_image.size(3) *
                                  new_image.size(4))  # number of patches, all features (C * PH * PW)
    images[i] = new_image

Patch Size: 4


# Custom Dataset and DataLoader with train and val data (80%, 20%)

In [4]:
from torch.utils.data import random_split, Dataset


class CustomDataset(Dataset):
    def __init__(self, inputs: torch.Tensor, outputs: torch.Tensor):
        self.inputs = inputs
        self.outputs = outputs

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.outputs[idx]

    def __setitem__(self, idx, value):
        self.inputs[idx] = value


dataset = CustomDataset(images, labels)
train_data, val_data = random_split(dataset, [0.8, 0.2])
print(f"Number of training samples: {len(train_data)}")
print(f"Number of validation samples: {len(val_data)}")

Number of training samples: 401
Number of validation samples: 100


In [5]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16

train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True
)

val_loader = DataLoader(
    val_data,
    batch_size=BATCH_SIZE,
    shuffle=False
)

print(f"Number of batches in training set: {len(train_loader)}")
print(f"Number of batches in validation set: {len(val_loader)}")

Number of batches in training set: 26
Number of batches in validation set: 7


In [6]:
class PatchEmbedding(nn.Module):
    def __init__(self, patch_dim, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.projection = nn.Linear(patch_dim, embed_dim)

    def forward(self, x):
        x = x.float()
        return self.projection(x)


In [7]:
class CLSTokenizer(nn.Module):
    def __init__(self, embed_dim):
        super(CLSTokenizer, self).__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        batch_size = x.size(0)
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        return x


In [8]:
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Initialize dimensions
        self.d_model = d_model  # Model's dimension
        self.num_heads = num_heads  # Number of attention heads
        self.d_k = d_model // num_heads  # Dimension of each head's key, query, and value

        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model)  # Query transformation
        self.W_k = nn.Linear(d_model, d_model)  # Key transformation
        self.W_v = nn.Linear(d_model, d_model)  # Value transformation
        self.W_o = nn.Linear(d_model, d_model)  # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        # Softmax is applied to obtain attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        # Reshape the input to have num_heads for multi-head attention
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Perform scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [9]:
class MLP(nn.Module):
    def __init__(self, d_model, d_ff):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.gelu = nn.GELU(approximate='tanh')

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))

In [10]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.mlp = MLP(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        residual_x = self.norm1(x)
        attn_output = self.self_attn(residual_x, residual_x, residual_x, mask)
        x = x + self.dropout(attn_output)
        residual_x = self.norm2(x)
        mlp_output = self.mlp(residual_x)
        x = x + self.dropout(mlp_output)
        return x

In [11]:
class ViTModel(nn.Module):
    def __init__(self, patch_dim, embed_dim, seq_len, num_heads, d_mlp, dropout, num_layers, num_classes):
        super(ViTModel, self).__init__()
        self.patch_embedding = PatchEmbedding(patch_dim, embed_dim)
        self.cls = CLSTokenizer(embed_dim)
        self.poss_embedding = nn.Embedding(seq_len, embed_dim)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(embed_dim, num_heads, d_mlp, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(embed_dim, num_classes)

    # def generate_mask(self, src, tgt):
    #     src_mask = (src != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2)
    #     tgt_mask = (tgt != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(3)
    #     src_mask, tgt_mask = src_mask.to(device), tgt_mask.to(device)
    #     seq_length = tgt.size(1)
    #     casual_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
    #     casual_mask = casual_mask.to(device)
    #     tgt_mask = tgt_mask & casual_mask
    #     return src_mask, tgt_mask

    def forward(self, x, mask=None):
        x = x.float()
        # 1. Patch embeddings
        x = self.patch_embedding(x)
        # 2. Add class token at the beginning
        x = self.cls(x)
        # 3. Add positional embeddings to patch embeddings (including class token)
        pos_embedding = self.poss_embedding(torch.arange(0, x[0].size(0))).to(device)
        x = x + pos_embedding
        # 4. Feed encoder layers
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, mask)
        cls_output = x[:, 0]
        output = self.fc(cls_output)
        return output


In [12]:
EMBEDDING_DIM = 128
PATCH_DIM = images[0].size(-1)
SEQ_LEN = images[0].size(0) + 1
NUM_HEADS = 8
NUM_LAYERS = 6
DIM_MLP = 3072
DROPOUT = 0.1
NUM_CLASSES = len(labels_map)

In [13]:
print('Embedding dim:', EMBEDDING_DIM)
print('Patch Dim:', PATCH_DIM)

vitModel = ViTModel(PATCH_DIM, EMBEDDING_DIM, SEQ_LEN, NUM_HEADS, DIM_MLP, DROPOUT, NUM_LAYERS, NUM_CLASSES)
my_model = vitModel.to(device)


Embedding dim: 128
Patch Dim: 48


# Train model

In [None]:
torch.manual_seed(42)
from transformers import get_linear_schedule_with_warmup
from torch import optim
from tqdm import tqdm

lr = 3e-4
optimizer = optim.AdamW(vitModel.parameters(), lr=lr, betas=(0.9, 0.95), eps=1e-8)

num_epochs = 10
num_training_steps = len(train_loader) * num_epochs
num_warmup_steps = int(0.1 * num_training_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

vitModel.train()
lr_history = []
loss_history = []
for epoch in range(num_epochs):
    optimizer.zero_grad()
    epoch_loss = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

    for src_data, tgt_data in progress_bar:
        src_data = src_data.to(device)
        tgt_data = tgt_data.to(device)

        output = vitModel(src_data)

        loss = F.cross_entropy(output, tgt_data)

        lr_history.append(optimizer.param_groups[0]['lr'])
        loss_history.append(loss.item())
        epoch_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vitModel.parameters(), 1.0)  # Prevent exploding gradients
        optimizer.step()  # Update parameters

        optimizer.zero_grad()
        scheduler.step()

        # Update progress bar with current loss
        progress_bar.set_postfix({"Batch Loss": loss.item()})
    print(f"Epoch: {epoch + 1}, Loss: {epoch_loss / len(train_loader):.4f}")

Epoch 1: 100%|██████████| 26/26 [00:08<00:00,  3.17it/s, Batch Loss=2.75]


Epoch: 1, Loss: 2.5539


Epoch 2:  31%|███       | 8/26 [00:02<00:05,  3.08it/s, Batch Loss=2.39]

# Show charts with lr and loss

In [None]:
# from matplotlib import pyplot as plt
# 
# # Ensure lr_history and loss_history are lists of equal length
# assert len(lr_history) == len(loss_history), "Length of lr_history and loss_history must be the same"
# 
# fig, ax1 = plt.subplots()
# 
# # Plot Loss on primary y-axis
# ax1.set_title("Learning Rate vs. Loss")
# ax1.set_xlabel("Training Step")
# ax1.set_ylabel("Loss", color='tab:red')
# ax1.plot(range(len(loss_history)), loss_history, color='tab:red', label='Loss')
# ax1.tick_params(axis='y', labelcolor='tab:red')
# 
# # Plot Learning Rate on secondary y-axis
# ax2 = ax1.twinx()
# ax2.set_ylabel("Learning Rate", color='tab:blue')
# ax2.plot(range(len(lr_history)), lr_history, color='tab:blue', linestyle='--', label='Learning Rate')
# ax2.tick_params(axis='y', labelcolor='tab:blue')
# 
# # Legends
# ax1.legend(loc='upper left')
# ax2.legend(loc='upper right')
# 
# plt.show()


# Calculate loss on validation data

In [None]:
vitModel.eval()

total_val_loss = 0.0

with torch.no_grad():
    # Use tqdm for progress bar
    progress_bar = tqdm(val_loader, desc="Validating")

    for val_src_data, val_tgt_data in progress_bar:
        # Move data to GPU
        val_src_data, val_tgt_data = val_src_data.to(device), val_tgt_data.to(device)

        # Forward pass
        val_output = vitModel(val_src_data)

        # Calculate loss
        val_loss = F.cross_entropy(
            val_output, val_tgt_data
        )

        total_val_loss += val_loss.item()

        # Update progress bar with current batch loss
        progress_bar.set_postfix({"Batch Loss": val_loss.item()})

# Calculate average validation loss
avg_val_loss = total_val_loss / len(val_loader)
print(f"Average Validation Loss: {avg_val_loss:.4f}")

# Model inference

In [None]:
def classify_image(image, model):
    model.eval()  # Set the model to evaluation mode

    src_data = image.unsqueeze(0)

    with torch.no_grad():
        # Generate output (prediction for next token)
        output = model(src_data)
        print('Output is:', output.shape)

        # Get the last token's logits and find the token with the highest probability
        image_class = output.argmax(dim=-1).item()
        image_class_name = labels_map[image_class]

        return image_class_name


def inference_from_datasets(train_dataset: bool = True, index: int = 0):
    if train_dataset:
        dataset = train_loader.dataset
    else:
        dataset = val_loader.dataset
    image = dataset[index][0]
    print('Image is:', image.shape)
    image_class_name = classify_image(image, model=vitModel)

    print('Dataset:', 'Train' if train_dataset else 'Validation')
    print('Generated class:', image_class_name)
    print('Real translation:', labels_map[dataset[index][1]])


inference_from_datasets(train_dataset=True, index=0)

In [None]:
# PATH = r"my_model_translation.pt"
# torch.save(transformer.state_dict(), PATH)

In [None]:
# next_model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_src_seq_len,
#                          max_tgt_seq_len, dropout)
# next_model.load_state_dict(torch.load(PATH, weights_only=True))
# next_model = next_model.to(device)
# # print(next_model)
#
# # sentence = tokenizer.decode(train_loader.dataset[0][0].tolist(), skip_special_tokens=True)
# sentence = "What are light beans there?"
# print(sentence)
# # sentence = "Prehistoric humans studied the relationship between the seasons and the length of days to plan their hunting and gathering activities."
# translation = translate_sentence(sentence, tokenizer, next_model)
# print(translation)