In [None]:
## imports
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd 
import numpy as np
import random
from tqdm import tqdm
import timeit

In [None]:
## Reproducibility: seed every RNG in use and force deterministic cuDNN kernels
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

In [None]:
## Model parameters
lr = 1e-4                # learning rate intended for the optimizer
num_classes = 10         # number of output logits of the classification head
patch_size = 4           # side length (pixels) of each square patch
in_channel = 1           # number of input image channels
num_heads = 8            # attention heads per encoder layer
dropout = 0.001          # dropout probability used by the transformer encoder layers
hidden_dim = 768         # feed-forward width requested from ViT
activation_fun = "gelu"  # activation of the encoder feed-forward sub-layer
num_encoders = 4         # number of stacked encoder layers
image_size = 28          # inputs are image_size x image_size
embd_dim = (patch_size**2) * in_channel        # token width = pixels in one flattened patch
num_patches = (image_size // patch_size) ** 2  # patch tokens per image (excluding [CLS])

# Fail fast on configurations the model cannot handle: multi-head attention splits
# embd_dim evenly across heads, and a patch grid that does not tile the image would
# silently drop the border pixels (num_patches uses floor division).
assert embd_dim % num_heads == 0, f"embd_dim ({embd_dim}) must be divisible by num_heads ({num_heads})"
assert image_size % patch_size == 0, f"image_size ({image_size}) must be divisible by patch_size ({patch_size})"

In [None]:
## Device to use: prefer CUDA, then Apple's MPS backend, and fall back to CPU so the
## notebook also runs on machines that have neither accelerator.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [None]:
## ViT (Vision Transformer) components
class ImagePatcher(nn.Module):
    """Turn a batch of square images into a token sequence for a transformer encoder.

    Each non-overlapping ``patch_size x patch_size`` patch is projected to ``embd_dim``
    features with a strided ``Conv2d``. A single learnable [CLS] token is prepended, and
    learnable position embeddings are added.

    Args:
        patch_size: side length of a square patch, in pixels.
        in_channel: number of input image channels.
        embd_dim: width of each output token.
        num_patches: patches per image; must equal ``(image_size // patch_size) ** 2``
            for the position-embedding addition to line up.
        image_size: expected height and width of the input images.
        dropout: dropout probability applied to the final token sequence
            (default 0.001, the value previously hard-coded).

    forward(x): ``x`` is ``(B, in_channel, image_size, image_size)``. Returns
    ``(B, num_patches + 1, embd_dim)``, with the [CLS] token at index 0.

    Raises:
        AssertionError: if the input height/width differ from ``image_size``.
    """

    def __init__(self, patch_size, in_channel, embd_dim, num_patches, image_size, dropout=0.001):
        super(ImagePatcher, self).__init__()
        self.image_size = image_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_channel, embd_dim, kernel_size=patch_size, stride=patch_size)

        # Exactly one [CLS] token per sequence, independent of the channel count:
        # a (1, in_channel, embd_dim) token would prepend in_channel tokens and break
        # the pos_embed addition for multi-channel input.
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embd_dim)))
        self.pos_embed = nn.Parameter(torch.randn(size=(1, self.num_patches + 1, embd_dim)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == W == self.image_size, (
            f"Input image size ({H}*{W}) doesn't match model ({self.image_size}*{self.image_size})."
        )

        x = self.proj(x)                  # (B, embd_dim, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embd_dim)

        cls_token = self.cls_token.expand(B, -1, -1)  # (B, 1, embd_dim)
        x = torch.cat([cls_token, x], dim=1)          # (B, num_patches + 1, embd_dim)

        # Add position information first and then regularize the full embedding
        # (the usual ViT ordering), so dropout also covers pos_embed's contribution.
        x = x + self.pos_embed
        return self.dropout(x)

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier: ImagePatcher -> pre-norm TransformerEncoder -> head.

    Args:
        num_patches, image_size, patch_size, in_channel, embd_dim: forwarded to
            ``ImagePatcher`` to build the token sequence.
        num_classes: number of output logits.
        num_encoders: number of stacked ``nn.TransformerEncoderLayer`` layers.
        num_heads: attention heads per layer (``embd_dim`` must be divisible by it).
        hidden_dim: width of each encoder layer's feed-forward sub-layer.
        activation_fun: feed-forward activation passed to ``nn.TransformerEncoderLayer``.
        dropout: dropout probability inside the encoder layers (default 0.001).

    forward(x) returns logits of shape ``(B, num_classes)``. They are computed from the
    token at index 0, which is where ImagePatcher places the [CLS] token.
    """

    def __init__(self, num_patches, image_size, num_classes, patch_size, embd_dim, num_encoders,
                 num_heads, hidden_dim, activation_fun, in_channel, dropout=0.001):
        super().__init__()

        self.embeddings_block = ImagePatcher(patch_size, in_channel, embd_dim, num_patches, image_size)

        # Pass hidden_dim and dropout explicitly. Without dim_feedforward,
        # nn.TransformerEncoderLayer falls back to its default width of 2048, and
        # reading dropout from a notebook global made the class depend on cell order.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embd_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation=activation_fun,
            batch_first=True,
            norm_first=True,
        )
        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embd_dim),
            nn.Linear(in_features=embd_dim, out_features=num_classes),
        )

    def forward(self, x):
        x = self.embeddings_block(x)      # token sequence with [CLS] at index 0
        x = self.encoder_blocks(x)        # batch_first=True: (B, tokens, embd_dim)
        return self.mlp_head(x[:, 0, :])  # classify from the [CLS] token -> (B, num_classes)


model = ViT(num_patches, image_size, num_classes, patch_size, embd_dim, num_encoders, num_heads,
            hidden_dim, activation_fun, in_channel, dropout=dropout).to(device)  # model(x) -> (batch, num_classes) logits

In [None]:
# Load both CSVs, then hold out 10% of the training rows (fixed seed) as a validation split.
train_df = pd.read_csv('./train.csv')
test_df = pd.read_csv('./test.csv')
train_df, val_df = train_test_split(
    train_df,
    shuffle=True,
    random_state=0,
    test_size=0.1,
)

In [None]:
# Preview a few rows rather than rendering the whole frame (column 0 is used as the label below).
train_df.head()

In [None]:
def _eval_transform():
    """Deterministic preprocessing: uint8 HxW array -> (1, H, W) float tensor normalized with mean/std 0.5."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])


class _FlatImageDataset(Dataset):
    """Shared base for the datasets below.

    Holds one flattened image per row of ``image``. On access, it reshapes the row to
    ``side x side`` uint8 and applies ``transform``. Each sample is a dict with key
    "image", then "label" (only when ``labels`` is given), then "index".
    """

    side = 28  # each row of `image` holds side*side pixels

    def __init__(self, image, indicies, transform, labels=None):
        self.image = image
        self.labels = labels
        self.indicies = indicies
        self.transform = transform

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):
        img = self.image[idx].reshape((self.side, self.side)).astype(np.uint8)
        sample = {"image": self.transform(img)}
        if self.labels is not None:
            sample["label"] = self.labels[idx]
        sample["index"] = self.indicies[idx]
        return sample


class Traindataset(_FlatImageDataset):
    """Labelled training samples with random-rotation augmentation (up to 15 degrees)."""

    def __init__(self, image, labels, indicies):
        super().__init__(
            image,
            indicies,
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ]),
            labels=labels,
        )


class Valdataset(_FlatImageDataset):
    """Labelled validation samples with deterministic preprocessing only."""

    def __init__(self, image, labels, indicies):
        super().__init__(image, indicies, _eval_transform(), labels=labels)


class Testdataset(_FlatImageDataset):
    """Unlabelled test samples; items carry only "image" and "index"."""

    def __init__(self, image, indicies):
        super().__init__(image, indicies, _eval_transform())
        

In [None]:
# Column 0 of train/val rows is the label and the remaining columns are pixels;
# every test_df column is treated as a pixel. Row indices are kept for traceability.
train_dataset = Traindataset(
    image=train_df.iloc[:, 1:].to_numpy().astype(np.uint8),
    labels=train_df.iloc[:, 0].to_numpy(),
    indicies=train_df.index.to_numpy(),
)
val_dataset = Valdataset(
    image=val_df.iloc[:, 1:].to_numpy().astype(np.uint8),
    labels=val_df.iloc[:, 0].to_numpy(),
    indicies=val_df.index.to_numpy(),
)
test_dataset = Testdataset(
    image=test_df.to_numpy().astype(np.uint8),
    indicies=test_df.index.to_numpy(),
)

In [None]:
# One batch size for every split. Only the training loader shuffles: a fixed
# evaluation order makes the validation metrics reproducible across runs and
# avoids drawing extra values from the global torch RNG on every validation pass.
BATCH_SIZE = 512

train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False)

test_dataloader = DataLoader(dataset=test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False)

In [None]:
criterion = nn.CrossEntropyLoss()
# Use `lr` from the "Model parameters" cell so the config cell is the single place to
# tune it (a separate hard-coded value here silently overrode that setting).
optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999), lr=lr, weight_decay=0)

In [None]:
EPOCH: int = 2  # number of full passes over train_dataloader

In [None]:
# Train for EPOCH epochs and report the sample-weighted mean training loss of each epoch.
start = timeit.default_timer()
for epoch in tqdm(range(EPOCH), position=0, leave=True):
    model.train()
    running_loss, seen = 0.0, 0
    for data in tqdm(train_dataloader, position=0, leave=True):
        img = data['image'].float().to(device)
        # CrossEntropyLoss documents class-index targets as int64; narrowing to uint8
        # relies on backend-specific support and would wrap class ids >= 256.
        labl = data['label'].long().to(device)

        y_pred = model(img)
        loss = criterion(y_pred, labl)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `loss` is a batch mean; weight by batch size so a short final batch is not
        # over-counted. Counting samples also avoids relying on a loop index
        # (undefined if the loader is empty).
        batch_n = labl.size(0)
        running_loss += loss.item() * batch_n
        seen += batch_n
    train_loss = running_loss / max(seen, 1)
    print(f"Epoch: {epoch} ; Training loss: {train_loss}")
    print()
print(f"Training time: {timeit.default_timer() - start:.1f} s")

In [None]:
# Evaluate on the held-out split: sample-weighted mean loss plus accuracy.
model.eval()
val_loss_sum, val_correct, val_seen = 0.0, 0, 0
with torch.no_grad():
    for data in tqdm(val_dataloader, position=0, leave=True):
        img = data["image"].float().to(device)
        # int64 class indices, as documented for CrossEntropyLoss targets
        labl = data["label"].long().to(device)
        y_pred = model(img)
        loss = criterion(y_pred, labl)

        batch_n = labl.size(0)
        val_loss_sum += loss.item() * batch_n  # undo the batch mean so batches weigh by size
        val_correct += (y_pred.argmax(dim=1) == labl).sum().item()
        val_seen += batch_n
val_loss = val_loss_sum / max(val_seen, 1)
val_acc = val_correct / max(val_seen, 1)
print(f"Validation loss: {val_loss}")
print(f"Validation accuracy: {val_acc:.4f}")