In [2]:
import torch
from torch import nn

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

In [4]:
import zipfile

zip_file_path = '/content/drive/MyDrive/GitProjects/archive (10).zip'

extract_to_directory = ''

with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to_directory)

print("Unzipping complete.")

Unzipping complete.


In [5]:
train_dataset_dir = 'seg_train/seg_train'
val_dataset_dir = 'seg_test/seg_test'

In [53]:
target_width, target_height = 224, 224

image_size = (target_width, target_height)
d_model = 512
block_size = 4
d_ff = 2048
N = 6
n_heads = 8
num_classes = 6
learning_rate = 3e-4
epochs = 25
batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

In [54]:
data_transform = transforms.Compose([
    transforms.Resize((target_width, target_height)),
    transforms.ToTensor(),
])

In [55]:
train_dataset = ImageFolder(root=train_dataset_dir, transform=data_transform)
val_dataset = ImageFolder(root=val_dataset_dir, transform=data_transform)

In [56]:

train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
for batch in train_data_loader:
    images, labels = batch
    images = images.to(device)
    labels = labels.to(device)
    break

In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbedCNN(nn.Module):
    def __init__(self, input_shape, d_dim):
        super(EmbedCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.fc_input_size = self.calculate_fc_input_size(input_shape)
        # print(self.fc_input_size)
        self.fc1 = nn.Linear(self.fc_input_size, d_dim)

    def calculate_fc_input_size(self, input_shape):
        x = torch.randn(1, *input_shape)
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        return x.view(1, -1).size(1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.pool(x)

        x = x.view(-1, self.fc_input_size)
        # print(x.shape)
        x = self.fc1(x)

        return x


In [58]:
class InputEmbeddings(nn.Module):
    def __init__(self, image_size, num_parts, d_model):
        super().__init__()
        self.image_size = image_size
        self.num_parts = num_parts
        self.d_model = d_model
        x , y = image_size
        self.x = int(x/num_parts)
        self.y = int(y/num_parts)
        self.embedding = EmbedCNN((3, self.x, self.y), d_model).to(device)

    def forward(self,images):
        result = []
        pos = []
        for i in range(self.num_parts):
            for j in range(self.num_parts):
                part = images[:, :, i * self.x  : (i+1) * self.x ,
                                   j * self.y  : (j+1) * self.y]
                # print(part.shape)
                result.append(self.embedding(part.to(device)))
                pos.append((i, j))
        return torch.cat([i.unsqueeze(1) for i in result], dim =1).to(device), pos

In [59]:
class ClassEmbeddings(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.embedding = nn.Embedding(1 , d_model)
    def forward(self, x):
        return self.embedding(x)

In [60]:
class PositionalEmbeddings(nn.Module):
    def __init__(self, block_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.block_size = block_size
        self.pos_embedding_x = nn.Embedding(block_size, d_model)
        self.pos_embedding_y = nn.Embedding(block_size, d_model)
    def forward(self, x , y):
        return self.pos_embedding_x(x) + self.pos_embedding_y(y)

class Embedding(nn.Module):
    def __init__(self, image_size, num_parts, d_model):
        super().__init__()
        self.inp_embed = InputEmbeddings(image_size, num_parts, d_model)
        self.pos_embed = PositionalEmbeddings(num_parts, d_model)
        self.class_embed = ClassEmbeddings(d_model)
        self.n = num_parts

    def forward(self, x):
        n = self.n
        pos_yinp = torch.cat([torch.arange(n) for _ in range(n)]).to(device)
        # print(pos_yinp)
        pos_xinp = torch.cat([torch.tensor([i for _ in range(n)]) for i in range(n)]).to(device)
        # print(pos_xinp)
        embedded = self.inp_embed(x)[0] + self.pos_embed(pos_xinp, pos_yinp).unsqueeze(0)
        class_embed = self.class_embed(torch.tensor([0]).to(device))
        B = embedded.shape[0]
        class_embed = torch.cat([class_embed.unsqueeze(0) for _ in range(B)]).to(device)
        return torch.cat([class_embed, embedded], dim = 1)

In [61]:
embeddings = Embedding(image_size, block_size, d_model).to(device)

In [62]:
embeddings(images).shape

torch.Size([32, 17, 512])

In [63]:
class Attention(nn.Module):
    def __init__(self, d_model, d, dropout=0.2):
        super().__init__()
        self.d_model = d_model
        self.d = d
        self.Q = nn.Linear(d_model, d)
        self.K = nn.Linear(d_model, d)
        self.V = nn.Linear(d_model, d)
        # self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)).to(device))
        self.dropout = nn.Dropout(dropout)
    def forward(self,x): #inp --> (64, 256, d_model)
        q = self.Q(x) #(64, 256, d)
        k = self.K(x) #(64, 256, d)
        v = self.V(x) #(64, 256, d)
        T = x.shape[1]
        weights = q@k.transpose(-2,-1)*k.shape[-1]**(-0.5) #(64, 256, 256)
        # weights = weights.masked_fill(self.tril[:T,:T] == 0, float('-inf'))
        weights = F.softmax(weights, dim = -1) #(64, 256, 256)
        out = weights @ v
        return out #(64, 256, d)


In [64]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout = 0.2):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.heads = nn.ModuleList([Attention(d_model, d_model//n_heads) for _ in range(n_heads)])
        self.proj = nn.Linear(n_heads * (d_model//n_heads) , d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [65]:
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model, d_ff, dropout = 0.2):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.dropout = nn.Dropout(dropout)
        self.W1 = nn.Linear(d_model, d_ff)
        self.W2 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        out = self.W1(x)
        out = F.relu(out)
        out = self.dropout(self.W2(out))
        return out


In [66]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_ff = d_ff
        self.multi_attention = MultiHeadAttention(d_model, n_heads)
        self.ffb = FeedForwardBlock(d_model, d_ff)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        out = self.multi_attention(x)
        out1 = x + self.ln1(out)
        out2 = self.ffb(out1)
        final_out = out1 + self.ln2(out2)
        return final_out

In [67]:
class VisionTransformer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, image_size, block_size, N, num_classes):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_ff = d_ff
        self.num_classes = num_classes
        self.image_size = image_size
        self.block_size = block_size
        self.inp_embed = Embedding(image_size, block_size, d_model)
        # self.pos_embed = PositionalEmbeddings(block_size, d_model)
        self.decoder_blocks = nn.ModuleList([EncoderBlock(d_model, n_heads, d_ff) for _ in range(N)])
        self.proj = nn.Linear(d_model, num_classes)


    def forward(self, x, targets = None):
        x = self.inp_embed(x)
        # block_size = x.shape[1]
        # x = x + self.pos_embed(torch.arange(block_size).to(device))
        for block in self.decoder_blocks:
            x = block(x)
        logits = self.proj(x[:,0,:])
        # logits = logits.mean(dim = 1)
        if targets is None:
            loss = None
        else:
            B, C = logits.shape
            logits = logits.view(B, C)
            targets = targets.view(B)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

In [68]:
model = VisionTransformer(d_model, n_heads, d_ff, image_size, block_size, N, num_classes).to(device)

In [69]:
sum(p.numel() for p in model.parameters() if p.requires_grad)/(1000000)

20.584454

In [70]:
model(images, labels)[0].shape

torch.Size([32, 6])

In [71]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [72]:
@torch.no_grad()

def calculate_loss(model, data_loader):
    out = {}
    model.eval()
    losses = torch.zeros(len(data_loader))
    true_pred = 0
    total_pred = 0
    for i, (batch, labels) in enumerate(data_loader):
        logits, loss = model(batch.to(device), labels.to(device))
        losses[i] = loss.item()
    return losses.mean().item()

In [73]:
from tqdm import tqdm
for epoch in range(epochs):
    losses = torch.zeros(len(train_data_loader))
    i = 0
    if epoch == 0:
        train_loss = calculate_loss(model, train_data_loader)
        val_loss = calculate_loss(model, val_data_loader)
        print(f"training and validation loss after epoch {epoch} is {train_loss} , {val_loss}")
    model.train()
    pred_trues = 0
    for batch, labels in tqdm(train_data_loader):
        batch = batch.to(device)
        labels = labels.to(device)
        logits, loss = model(batch, labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        losses[i] = loss.item()
        i+=1
        y_1 = logits
        y = labels
        pred_trues += sum(y == y_1.argmax(dim=1))
    train_loss = losses.mean()
    val_loss = calculate_loss(model, val_data_loader)
    train_acc = pred_trues/(len(train_data_loader)*batch_size)
    pred_true = 0
    for x, y in val_data_loader:
        x = x.to(device)
        y = y.to(device)
        y_1 = model(x)[0]
        pred_true += (y == y_1.argmax()).item()
    val_acc = pred_true/len(val_data_loader)

    print(f"training and validation loss after epoch {epoch + 1} is {train_loss} , {val_loss}")
    print(f"training and validation accuracy after epoch {epoch + 1} is {train_acc} , {val_acc}")

training and validation loss after epoch 0 is 4.735386371612549 , 4.775509834289551


100%|██████████| 439/439 [01:23<00:00,  5.23it/s]


training and validation loss after epoch 1 is 1.9619485139846802 , 1.067578911781311
training and validation accuracy after epoch 1 is 0.4397067427635193 , 0.5803333333333334


100%|██████████| 439/439 [01:24<00:00,  5.20it/s]


training and validation loss after epoch 2 is 1.0201154947280884 , 0.9110615253448486
training and validation accuracy after epoch 2 is 0.6040005683898926 , 0.6713333333333333


100%|██████████| 439/439 [01:23<00:00,  5.25it/s]


training and validation loss after epoch 3 is 0.9102698564529419 , 0.8412778973579407
training and validation accuracy after epoch 3 is 0.6543280482292175 , 0.676


100%|██████████| 439/439 [01:22<00:00,  5.30it/s]


training and validation loss after epoch 4 is 0.8432556986808777 , 0.8809086680412292
training and validation accuracy after epoch 4 is 0.6742597222328186 , 0.6706666666666666


100%|██████████| 439/439 [01:23<00:00,  5.24it/s]


training and validation loss after epoch 5 is 0.7816968560218811 , 0.720420241355896
training and validation accuracy after epoch 5 is 0.7030894160270691 , 0.732


100%|██████████| 439/439 [01:23<00:00,  5.28it/s]


training and validation loss after epoch 6 is 0.7510222792625427 , 0.7862509489059448
training and validation accuracy after epoch 6 is 0.7191059589385986 , 0.6926666666666667


100%|██████████| 439/439 [01:23<00:00,  5.26it/s]


training and validation loss after epoch 7 is 0.7718472480773926 , 0.8227968811988831
training and validation accuracy after epoch 7 is 0.7084282636642456 , 0.712


100%|██████████| 439/439 [01:20<00:00,  5.43it/s]


training and validation loss after epoch 8 is 0.7756748795509338 , 0.7916551828384399
training and validation accuracy after epoch 8 is 0.7066486477851868 , 0.723


100%|██████████| 439/439 [01:21<00:00,  5.37it/s]


training and validation loss after epoch 9 is 0.7296611070632935 , 0.725904643535614
training and validation accuracy after epoch 9 is 0.7271497845649719 , 0.735


100%|██████████| 439/439 [01:22<00:00,  5.34it/s]


KeyboardInterrupt: 