In [1]:
from datasets import load_dataset
import pandas as pd

from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

# ---- Reproducibility ----
# Seed the global RNG before any parameters are created so that weight
# initialisation (including the randn-initialised cls/pos embeddings) is
# repeatable across "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)

# ---- Device ----
# Make the GPU (when present) the default device for newly created tensors
# and module parameters.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
img_to_tensor = transforms.ToTensor()  # image -> tensor converter used by the dataset
device = torch.get_default_device()
# Generator handed to the DataLoaders; created on the default device and
# seeded so the shuffle order is repeatable as well.
generator = torch.Generator(device)
generator.manual_seed(seed)

# ---- Patch / transformer geometry ----
img_size = 32
patch_size = 4
n_patches = (img_size // patch_size) ** 2  # patches per image: (32 // 4) ** 2 = 64
n_heads = 4
trf_blocks = 4

# ---- Classifier / optimisation sizes ----
n_classes = 10
embed_dim = 64
batch_size = 100

print(device)

cuda:0


In [2]:
class ImageDataset(Dataset):
    """Map-style dataset over rows that expose ``"img"`` and ``"label"`` fields.

    Indexing yields an ``(image, label)`` pair of tensors. The image is
    converted with the module-level ``img_to_tensor`` and moved to the
    module-level ``device``; the label is wrapped with ``torch.tensor``.
    """

    def __init__(self, data):
        """Keep a reference to ``data``, which must support ``len()`` and integer indexing."""
        self.data = data

    def __len__(self):
        """Return the number of rows in the wrapped data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image_tensor, label_tensor)`` pair for row ``idx``."""
        # Look the row up once. Indexing the backing store twice per sample
        # repeats whatever per-row work it performs on access.
        row = self.data[idx]
        image = img_to_tensor(row["img"]).to(device)
        label = torch.tensor(row["label"])
        return image, label
    
# Load the "uoft-cs/cifar10" dataset through `datasets.load_dataset` and wrap
# its "train" and "test" splits for batching.
ds = load_dataset("uoft-cs/cifar10")

train_set = ImageDataset(ds["train"])
test_set = ImageDataset(ds["test"])

# Training batches are reshuffled on every pass and a trailing partial batch
# is dropped; evaluation batches keep the stored order.
train_dl = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    generator=generator,
)
test_dl = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    generator=generator,
)

In [3]:
class PatchEmbeddings(nn.Module):
    """Convert an image batch into a sequence of patch embeddings via two convolutions."""

    def __init__(self):
        super().__init__()
        # Patch projection: kernel and stride both equal patch_size, so each
        # output position covers one non-overlapping patch.
        self.embed = nn.Conv2d(
            in_channels=3,
            out_channels=embed_dim,
            kernel_size=(patch_size, patch_size),
            stride=patch_size,
        )
        # 3x3 convolution across the patch grid; stride 1 with padding 1
        # leaves the grid size unchanged.
        self.enrich = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=(3, 3),
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Map ``batch x 3 x H x W`` images to ``batch x n_patches x embed_dim`` tokens."""
        patch_grid = self.enrich(self.embed(x))  # batch x embed_dim x H/patch_size x W/patch_size
        tokens = patch_grid.flatten(start_dim=2)  # batch x embed_dim x n_patches
        return tokens.permute(0, 2, 1)  # batch x n_patches x embed_dim
    
class TransformerEncoder(nn.Module):
    """One pre-norm transformer block: self-attention then an MLP, each with a residual add.

    Input and output both have shape ``batch x tokens x embed_dim``
    (the attention layer is built with ``batch_first=True``).
    """

    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.attention = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)

        # Position-wise feed-forward network with a 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers and return the updated tokens."""
        # Attention sub-layer. The attention weights are never used, so skip
        # computing them (need_weights=False) instead of discarding them.
        normed = self.norm1(x)
        attended, _ = self.attention(normed, normed, normed, need_weights=False)
        # Out-of-place residual adds: never overwrite a tensor that autograd
        # may still need for the backward pass.
        x = x + attended

        # MLP sub-layer.
        x = x + self.mlp(self.norm2(x))
        return x
    
class VisionTransformer(nn.Module):
    """Patch-based image classifier: patch embedding, class token, transformer blocks, linear head.

    All sizes come from the module-level hyperparameters ``embed_dim``,
    ``n_patches``, ``trf_blocks`` and ``n_classes``.
    """

    def __init__(self):
        super().__init__()
        self.patch_embedding = PatchEmbeddings()
        # Learnable class token, prepended to every patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))  # 1 x 1 x embed_dim
        # Learnable positional embedding for the class token plus every patch.
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, embed_dim))  # 1 x (n_patches+1) x embed_dim
        self.transformer_layers = nn.Sequential(
            *[TransformerEncoder() for _ in range(trf_blocks)]
        )

        # Size the classifier from the configured class count rather than a
        # hard-coded literal, so changing n_classes actually takes effect.
        self.out_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, n_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``batch x n_classes`` for an image batch ``x``."""
        # x: batch x channels x img_size x img_size
        x = self.patch_embedding(x)  # batch x n_patches x embed_dim
        batch = x.size(0)

        # Broadcast the single class token across the batch (no copy).
        cls_tokens = self.cls_token.expand(batch, -1, -1)  # batch x 1 x embed_dim

        x = torch.cat((cls_tokens, x), dim=1)  # batch x (n_patches+1) x embed_dim
        x = x + self.pos_embed  # batch x (n_patches+1) x embed_dim
        x = self.transformer_layers(x)  # batch x (n_patches+1) x embed_dim
        # Classify from the first token, i.e. the class token prepended above.
        x = x[:, 0]  # batch x embed_dim
        x = self.out_head(x)  # batch x n_classes
        return x

In [4]:
# Training setup: epoch budget, a freshly initialised model, the loss, the
# optimiser (constructed with its default hyperparameters), and the list that
# collects one metrics dict per finished epoch.
epochs = 50
model = VisionTransformer()
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.AdamW(model.parameters())
history = []

# Report the model size as the total element count over all parameters.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"Total parameters: {total_params:,}")

Total parameters: 245,002


In [5]:
for epoch in range(epochs):
    # ---- Train for one epoch ----
    # Restore training mode: the evaluation phase below switches the model to
    # eval mode, so without this call every epoch after the first would train
    # with the model still in eval mode.
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0

    for input_batch, label_batch in tqdm(train_dl):
        opt.zero_grad()

        pred_batch = model(input_batch)
        loss = loss_fn(pred_batch, label_batch)
        loss.backward()

        opt.step()

        with torch.no_grad():
            n_samples = input_batch.size(0)
            # Weight the batch-mean loss by batch size so the epoch figure is
            # a per-sample mean, matching the evaluation code below.
            total_loss += loss.item() * n_samples
            # Vectorised accuracy: one device sync per batch, not one per sample.
            correct += (pred_batch.argmax(dim=1) == label_batch).sum().item()
            seen += n_samples

    train_loss = total_loss / seen
    train_acc = correct / seen

    # ---- Evaluate on the test split ----
    model.eval()
    test_loss, test_correct, m = 0, 0, 0

    with torch.no_grad():
        for input_batch, label_batch in test_dl:
            logits = model(input_batch)
            loss = loss_fn(logits, label_batch)

            test_loss += loss.item() * input_batch.size(0)
            preds = logits.argmax(dim=1)
            test_correct += (preds == label_batch).sum().item()
            m += input_batch.size(0)

    test_loss /= m
    test_acc = test_correct / m

    metrics = {
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'test_loss': test_loss,
        'test_acc': test_acc
    }
    print(metrics, "\n\n")
    # ---- Log metrics ----
    history.append(metrics)

    # Rewrite the full history and the latest weights after every epoch so an
    # interrupted run still leaves usable artifacts on disk.
    history_df = pd.DataFrame(history)
    history_df.to_csv("./history_convt.csv", index=False)

    torch.save(model.state_dict(), "./convt_model.pth")

100%|██████████| 500/500 [00:29<00:00, 17.22it/s]


{'epoch': 1, 'train_loss': 1.8065483801364899, 'train_acc': 0.3389, 'test_loss': 1.5417100512981414, 'test_acc': 0.4374} 




100%|██████████| 500/500 [00:25<00:00, 19.43it/s]


{'epoch': 2, 'train_loss': 1.436655358314514, 'train_acc': 0.48034, 'test_loss': 1.395037395954132, 'test_acc': 0.4974} 




100%|██████████| 500/500 [00:25<00:00, 19.41it/s]


{'epoch': 3, 'train_loss': 1.3042364282608032, 'train_acc': 0.5313, 'test_loss': 1.3064229583740234, 'test_acc': 0.5332} 




100%|██████████| 500/500 [00:25<00:00, 19.43it/s]


{'epoch': 4, 'train_loss': 1.2120930855274201, 'train_acc': 0.56212, 'test_loss': 1.254409909248352, 'test_acc': 0.5465} 




100%|██████████| 500/500 [00:25<00:00, 19.29it/s]


{'epoch': 5, 'train_loss': 1.138690283536911, 'train_acc': 0.59142, 'test_loss': 1.2067651224136353, 'test_acc': 0.5666} 




100%|██████████| 500/500 [00:26<00:00, 19.19it/s]


{'epoch': 6, 'train_loss': 1.0756187553405763, 'train_acc': 0.6141, 'test_loss': 1.1896785807609558, 'test_acc': 0.5736} 




100%|██████████| 500/500 [00:26<00:00, 18.85it/s]


{'epoch': 7, 'train_loss': 1.0134171195030213, 'train_acc': 0.63828, 'test_loss': 1.1523919028043748, 'test_acc': 0.5868} 




100%|██████████| 500/500 [00:26<00:00, 18.99it/s]


{'epoch': 8, 'train_loss': 0.9559067809581756, 'train_acc': 0.6565, 'test_loss': 1.1661545574665069, 'test_acc': 0.5905} 




100%|██████████| 500/500 [00:26<00:00, 18.94it/s]


{'epoch': 9, 'train_loss': 0.8981393092870712, 'train_acc': 0.67848, 'test_loss': 1.1648722612857818, 'test_acc': 0.592} 




100%|██████████| 500/500 [00:26<00:00, 18.78it/s]


{'epoch': 10, 'train_loss': 0.840748625934124, 'train_acc': 0.70022, 'test_loss': 1.1632417780160904, 'test_acc': 0.599} 




100%|██████████| 500/500 [00:26<00:00, 18.96it/s]


{'epoch': 11, 'train_loss': 0.7825576213598251, 'train_acc': 0.7216, 'test_loss': 1.165703343153, 'test_acc': 0.5971} 




100%|██████████| 500/500 [00:26<00:00, 19.15it/s]


{'epoch': 12, 'train_loss': 0.7232865092158317, 'train_acc': 0.743, 'test_loss': 1.1979321897029878, 'test_acc': 0.603} 




100%|██████████| 500/500 [00:25<00:00, 19.37it/s]


{'epoch': 13, 'train_loss': 0.6636950560212135, 'train_acc': 0.76286, 'test_loss': 1.252176753282547, 'test_acc': 0.5981} 




100%|██████████| 500/500 [00:26<00:00, 19.08it/s]


{'epoch': 14, 'train_loss': 0.610406551361084, 'train_acc': 0.78176, 'test_loss': 1.2999709409475326, 'test_acc': 0.5921} 




100%|██████████| 500/500 [00:26<00:00, 19.20it/s]


{'epoch': 15, 'train_loss': 0.5549638003110886, 'train_acc': 0.80284, 'test_loss': 1.3580895614624025, 'test_acc': 0.5917} 




100%|██████████| 500/500 [00:24<00:00, 20.54it/s]


{'epoch': 16, 'train_loss': 0.5139704847335815, 'train_acc': 0.81656, 'test_loss': 1.354179652929306, 'test_acc': 0.5952} 




100%|██████████| 500/500 [00:25<00:00, 19.64it/s]


{'epoch': 17, 'train_loss': 0.4662900579571724, 'train_acc': 0.83332, 'test_loss': 1.4368677270412444, 'test_acc': 0.5977} 




100%|██████████| 500/500 [00:25<00:00, 19.78it/s]


{'epoch': 18, 'train_loss': 0.4164794645309448, 'train_acc': 0.85094, 'test_loss': 1.5084602773189544, 'test_acc': 0.5882} 




100%|██████████| 500/500 [00:26<00:00, 18.88it/s]


{'epoch': 19, 'train_loss': 0.387024217158556, 'train_acc': 0.86254, 'test_loss': 1.5560250318050384, 'test_acc': 0.5925} 




100%|██████████| 500/500 [00:26<00:00, 18.71it/s]


{'epoch': 20, 'train_loss': 0.34945813989639285, 'train_acc': 0.87594, 'test_loss': 1.6153416192531587, 'test_acc': 0.5886} 




100%|██████████| 500/500 [00:26<00:00, 18.57it/s]


{'epoch': 21, 'train_loss': 0.3246192290186882, 'train_acc': 0.88294, 'test_loss': 1.7089925372600556, 'test_acc': 0.5898} 




100%|██████████| 500/500 [00:27<00:00, 18.38it/s]


{'epoch': 22, 'train_loss': 0.2938841501474381, 'train_acc': 0.89366, 'test_loss': 1.7349349796772002, 'test_acc': 0.5962} 




100%|██████████| 500/500 [00:27<00:00, 18.34it/s]


{'epoch': 23, 'train_loss': 0.27879967604577544, 'train_acc': 0.90044, 'test_loss': 1.768889307975769, 'test_acc': 0.5896} 




100%|██████████| 500/500 [00:26<00:00, 18.60it/s]


{'epoch': 24, 'train_loss': 0.2559989265650511, 'train_acc': 0.9077, 'test_loss': 1.7510465216636657, 'test_acc': 0.5875} 




100%|██████████| 500/500 [00:26<00:00, 18.65it/s]


{'epoch': 25, 'train_loss': 0.24727791196107865, 'train_acc': 0.91238, 'test_loss': 1.8604367530345918, 'test_acc': 0.5859} 




100%|██████████| 500/500 [00:26<00:00, 19.01it/s]


{'epoch': 26, 'train_loss': 0.22074236719310283, 'train_acc': 0.9216, 'test_loss': 1.9761680686473846, 'test_acc': 0.5915} 




100%|██████████| 500/500 [00:26<00:00, 18.91it/s]


{'epoch': 27, 'train_loss': 0.21692191153764726, 'train_acc': 0.92356, 'test_loss': 1.9653444051742555, 'test_acc': 0.5825} 




100%|██████████| 500/500 [00:26<00:00, 18.65it/s]


{'epoch': 28, 'train_loss': 0.21093580703437328, 'train_acc': 0.92466, 'test_loss': 1.9744904100894929, 'test_acc': 0.5875} 




100%|██████████| 500/500 [00:27<00:00, 18.50it/s]


{'epoch': 29, 'train_loss': 0.19561305993050337, 'train_acc': 0.93004, 'test_loss': 2.0676655089855194, 'test_acc': 0.5827} 




100%|██████████| 500/500 [00:26<00:00, 18.61it/s]


{'epoch': 30, 'train_loss': 0.1889651182293892, 'train_acc': 0.9321, 'test_loss': 2.0787921571731567, 'test_acc': 0.5868} 




100%|██████████| 500/500 [00:25<00:00, 19.70it/s]


{'epoch': 31, 'train_loss': 0.18625081608444452, 'train_acc': 0.93262, 'test_loss': 2.1586960327625273, 'test_acc': 0.5843} 




100%|██████████| 500/500 [00:25<00:00, 19.63it/s]


{'epoch': 32, 'train_loss': 0.17489683453738689, 'train_acc': 0.93774, 'test_loss': 2.178310556411743, 'test_acc': 0.5842} 




100%|██████████| 500/500 [00:27<00:00, 18.29it/s]


{'epoch': 33, 'train_loss': 0.17282871478050948, 'train_acc': 0.93794, 'test_loss': 2.155194458961487, 'test_acc': 0.582} 




100%|██████████| 500/500 [00:26<00:00, 18.83it/s]


{'epoch': 34, 'train_loss': 0.15410504460334778, 'train_acc': 0.94484, 'test_loss': 2.156431450843811, 'test_acc': 0.5866} 




100%|██████████| 500/500 [00:25<00:00, 19.23it/s]


{'epoch': 35, 'train_loss': 0.16603396837413312, 'train_acc': 0.94028, 'test_loss': 2.2372057044506075, 'test_acc': 0.5847} 




100%|██████████| 500/500 [00:27<00:00, 18.45it/s]


{'epoch': 36, 'train_loss': 0.16039956647902728, 'train_acc': 0.9438, 'test_loss': 2.1539224863052366, 'test_acc': 0.5826} 




100%|██████████| 500/500 [00:27<00:00, 18.29it/s]


{'epoch': 37, 'train_loss': 0.14732290572300552, 'train_acc': 0.9477, 'test_loss': 2.276489591598511, 'test_acc': 0.5847} 




100%|██████████| 500/500 [00:26<00:00, 19.14it/s]


{'epoch': 38, 'train_loss': 0.14363413087278604, 'train_acc': 0.9494, 'test_loss': 2.282418601512909, 'test_acc': 0.583} 




100%|██████████| 500/500 [00:27<00:00, 18.51it/s]


{'epoch': 39, 'train_loss': 0.15354240763932467, 'train_acc': 0.94588, 'test_loss': 2.2342197096347807, 'test_acc': 0.5837} 




100%|██████████| 500/500 [00:26<00:00, 18.58it/s]


{'epoch': 40, 'train_loss': 0.1346971834488213, 'train_acc': 0.95268, 'test_loss': 2.3644027376174925, 'test_acc': 0.5818} 




100%|██████████| 500/500 [00:26<00:00, 18.86it/s]


{'epoch': 41, 'train_loss': 0.14937852432578802, 'train_acc': 0.94736, 'test_loss': 2.2096882498264314, 'test_acc': 0.5885} 




100%|██████████| 500/500 [00:26<00:00, 19.17it/s]


{'epoch': 42, 'train_loss': 0.13353882147371768, 'train_acc': 0.95208, 'test_loss': 2.3283363914489748, 'test_acc': 0.5837} 




100%|██████████| 500/500 [00:26<00:00, 19.18it/s]


{'epoch': 43, 'train_loss': 0.13673138508945704, 'train_acc': 0.95192, 'test_loss': 2.3556365168094633, 'test_acc': 0.586} 




100%|██████████| 500/500 [00:26<00:00, 18.57it/s]


{'epoch': 44, 'train_loss': 0.12862910659611226, 'train_acc': 0.95512, 'test_loss': 2.335206027030945, 'test_acc': 0.5889} 




100%|██████████| 500/500 [00:27<00:00, 18.39it/s]


{'epoch': 45, 'train_loss': 0.13202239914610983, 'train_acc': 0.95408, 'test_loss': 2.2981023943424224, 'test_acc': 0.5865} 




100%|██████████| 500/500 [00:26<00:00, 18.77it/s]


{'epoch': 46, 'train_loss': 0.12793906078860165, 'train_acc': 0.95578, 'test_loss': 2.393720734119415, 'test_acc': 0.5846} 




100%|██████████| 500/500 [00:27<00:00, 18.44it/s]


{'epoch': 47, 'train_loss': 0.13030622460320593, 'train_acc': 0.9542, 'test_loss': 2.36507269859314, 'test_acc': 0.5861} 




100%|██████████| 500/500 [00:27<00:00, 18.31it/s]


{'epoch': 48, 'train_loss': 0.1129658790640533, 'train_acc': 0.95994, 'test_loss': 2.4158928644657136, 'test_acc': 0.5929} 




100%|██████████| 500/500 [00:27<00:00, 18.31it/s]


{'epoch': 49, 'train_loss': 0.14208288462087512, 'train_acc': 0.95054, 'test_loss': 2.2960168993473054, 'test_acc': 0.5817} 




100%|██████████| 500/500 [00:26<00:00, 18.83it/s]


{'epoch': 50, 'train_loss': 0.12622795451804997, 'train_acc': 0.95592, 'test_loss': 2.3087862718105314, 'test_acc': 0.5933} 




In [7]:
import plotly.express as px

# Reload the per-epoch metrics from the CSV file and draw one line chart per
# metric family (losses first, then accuracies), with epoch on the x-axis.
history_df = pd.read_csv("./history_convt.csv")

for metric_columns in (["train_loss", "test_loss"], ["train_acc", "test_acc"]):
    fig = px.line(history_df, x="epoch", y=metric_columns)
    fig.show()