In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
# from torchsummary import summary

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
import glob, random, os, warnings

In [None]:
def seed_everything(seed=0):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random``, NumPy, TensorFlow and PyTorch (CPU and every CUDA
    device). PyTorch draws most of the randomness here (``torch.randn`` parameter
    init, DataLoader shuffling), so it has to be seeded as well.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    torch.manual_seed(seed)
    # Explicit for older torch versions; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # PYTHONHASHSEED only affects Python processes started after this point
    # (e.g. spawned workers), not the hash seed of the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'

seed_everything()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Display one raw training image as a quick visual sanity check of the data.
sample_image_path = r'C:\Users\MOJAHID HUSSAIN\Desktop\vit\cassava-leaf-disease-classification\train_images\6103.jpg'
img = Image.open(sample_image_path)
fig = plt.figure()
plt.imshow(img)

In [None]:
# Resize the sample image to the 224x224 ImageNet resolution, convert it to a
# (C, H, W) float tensor, then prepend a batch axis -> (1, C, 224, 224).
transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)[None]
x.shape

In [None]:
# Cut the sample into non-overlapping patch_size x patch_size patches and flatten
# each one: (b, c, H, W) -> (b, number_of_patches, patch_size * patch_size * c).
patch_size = 16  # patch side length in pixels
pathes = rearrange(
    x,
    'b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
    s1=patch_size,
    s2=patch_size,
)

In [None]:
# Path
# Dataset root directory (contains train.csv, train_images/ and test_images/).
# Set the CASSAVA_DATA_DIR environment variable to point at a different copy
# instead of editing the hard-coded default below.
Data_path = os.environ.get(
    "CASSAVA_DATA_DIR",
    r'C:\Users\MOJAHID HUSSAIN\Desktop\vit\cassava-leaf-disease-classification',
)
# Everything else is derived from the root so the locations cannot drift apart.
training_images = os.path.join(Data_path, 'train_images')
training_path = os.path.join(Data_path, 'train.csv')
testing_path = os.path.join(Data_path, 'test_images')

In [None]:
# Load the labelled training table; each row pairs an image file name with its label.
wholeData = pd.read_csv(os.path.join(Data_path, "train.csv"))
wholeData.head()

# Hold out 20% of the rows for validation. Stratifying on the label keeps the
# class proportions of the full table in both subsets.
training_data, valid_data = train_test_split(
    wholeData,
    test_size=0.2,
    random_state=42,
    stratify=wholeData.label.values,
)
valid_data.head()

In [None]:
# Compare the label distribution of the full table and of both splits to check
# that the stratified split preserved the class proportions.
import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots()
# sort_index() gives every series the same label order. Otherwise each
# value_counts() is ordered by its own counts and the grouped bars drawn at a
# tick could belong to different classes.
split_counts = [
    (wholeData.label.value_counts().sort_index(), 'blue', 0, 'Whole Data'),
    (training_data.label.value_counts().sort_index(), 'orange', 1, 'Training Data'),
    (valid_data.label.value_counts().sort_index(), 'green', 2, 'Validation Data'),
]
for counts, color, position, name in split_counts:
    # `position` offsets each series' bar around the tick so the three sit side by side.
    counts.plot(kind='bar', color=color, position=position, width=0.20, label=name, ax=ax)
ax.set(title='Label distribution per split', xlabel='label', ylabel='number of images')
ax.legend()
plt.show()

In [None]:
from PIL import Image

class CassavaDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from a DataFrame.

    Args:
        wholeData: DataFrame whose rows unpack into exactly two values,
            ``(image file name, label)``.
        data_path: root directory that contains ``train_images`` / ``test_images``.
        transform: optional callable applied to the loaded RGB PIL image.
        mode: ``"train"`` reads images from ``train_images``; any other value
            reads from ``test_images``.
    """

    def __init__(self, wholeData, data_path=Data_path, transform=None, mode="train"):
        super().__init__()
        self.data = wholeData.values  # DataFrame -> 2D array of rows
        self.data_path = data_path
        self.transform = transform
        self.mode = mode
        self.data_dir = "train_images" if mode == "train" else "test_images"

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_name, label = self.data[index]
        image_path = os.path.join(self.data_path, self.data_dir, image_name)
        # Image.open is lazy and keeps the file open. convert() returns a new,
        # fully loaded image, so the source file can be closed right away instead
        # of waiting for garbage collection (matters with many loader workers).
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
image_size=224
from torchvision.transforms import transforms

# ImageNet per-channel mean / std, shared by both pipelines below.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

# Training pipeline: resize, random flip/rotation/affine/crop augmentation,
# then tensor conversion and ImageNet normalisation.
# NOTE(review): RandomAffine(100) allows rotations of up to +/-100 degrees on top
# of RandomRotation(10); confirm this augmentation strength is intended.
train_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(10),
    transforms.RandomAffine(100),
    transforms.RandomResizedCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
]
transformsTrain = transforms.Compose(train_steps)

# Validation pipeline: deterministic resize + normalisation only, no augmentation.
transformsValid = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

In [None]:
# The datasets are lazy: indexing opens the image file and applies the transform
# on demand, returning (image_tensor, label).
train_dataset, valid_dataset = (
    CassavaDataset(training_data, transform=transformsTrain),
    CassavaDataset(valid_data, transform=transformsValid),
)
# Shape of one transformed training image: (3, image_size, image_size).
print(train_dataset[14][0].shape)

In [None]:
import multiprocessing

# DataLoader worker processes must unpickle CassavaDataset. Under the "spawn" and
# "forkserver" start methods ("spawn" is the default on Windows and macOS), a class
# defined in a notebook cell cannot be found in the workers' __main__, so
# multi-worker loading fails there. Only use workers when they are forked from
# this kernel; otherwise load in the main process.
num_loader_workers = 4 if multiprocessing.get_start_method() == "fork" else 0
batch_size = 48

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_loader_workers,
)

valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_loader_workers,
)

In [None]:
class ResidualAdd(nn.Module):
    """Wrap a module ``fn`` with a residual (skip) connection: ``fn(x) + x``.

    Extra keyword arguments given to ``forward`` are passed through to ``fn``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add. ``x += res`` would write into the tensor returned by
        # ``fn``. That tensor IS the caller's input when ``fn`` returns its argument
        # (e.g. nn.Identity), and autograd raises if ``fn``'s last op saved it for backward.
        return self.fn(x, **kwargs) + x

In [None]:
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    The hidden layer is ``expansion`` times wider than ``emb_size``. The output
    has width ``emb_size`` again, matching the input width.
    """

    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)

In [None]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch tokens.

    A Conv2d with ``kernel_size == stride == patch_size`` projects every
    non-overlapping patch to ``emb_size`` features. A learned cls token is then
    prepended and learned position embeddings are added. For an
    ``img_size x img_size`` input the output is
    ``(batch, (img_size // patch_size) ** 2 + 1, emb_size)``.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        # nn.Module.__init__ must run before any attribute is assigned on the module.
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # One position vector per patch plus one for the cls token.
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the patch tokens
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding (broadcast over the batch), out-of-place
        x = x + self.positions
        return x
    
# Sanity check on the 224x224 sample: (1, 14*14 patches + 1 cls token, 768) = (1, 197, 768).
PatchEmbedding()(x).shape

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        emb_size: token embedding width; must be divisible by ``num_heads``.
        num_heads: number of attention heads, each working on
            ``emb_size // num_heads`` features.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Attend over the tokens of ``x``.

        Args:
            x: ``(batch, tokens, emb_size)`` input.
            mask: optional boolean tensor broadcastable to
                ``(batch, heads, tokens, tokens)``; ``False`` entries mark
                query/key pairs that must not attend to each other.

        Returns:
            ``(batch, tokens, emb_size)`` tensor.
        """
        batch, tokens, _ = x.shape
        head_dim = self.emb_size // self.num_heads
        # Same layout as einops "b n (h d qkv) -> qkv b h n d": split the fused
        # projection's last axis as (heads, head_dim, 3), then move q/k/v to the front.
        qkv = self.qkv(x).reshape(batch, tokens, self.num_heads, head_dim, 3).permute(4, 0, 2, 1, 3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # Scale the logits by sqrt(head_dim) BEFORE the softmax. Scaling afterwards
        # (as before) shrinks the weights so rows no longer sum to 1.
        energy = torch.einsum('bhqd,bhkd->bhqk', queries, keys) / head_dim ** 0.5
        if mask is not None:
            # masked_fill is out-of-place, so its result must be kept.
            energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)

        att = F.softmax(energy, dim=-1)
        att = self.att_drop(att)
        # weighted sum of the values over the key axis
        out = torch.einsum('bhal,bhlv->bhav', att, values)
        # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
        out = out.transpose(1, 2).reshape(batch, tokens, self.emb_size)
        return self.projection(out)
    
# Sanity check: attention keeps the (batch, tokens, emb_size) shape of the embedded sample.
patches_embedded = PatchEmbedding()(x)
MultiHeadAttention()(patches_embedded).shape

In [None]:
class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm transformer encoder layer made of two residual sub-layers.

    1. ``x + Dropout(MultiHeadAttention(LayerNorm(x)))``
    2. ``x + Dropout(FeedForwardBlock(LayerNorm(x)))``

    Extra ``kwargs`` (e.g. ``num_heads``, ``dropout``) go to MultiHeadAttention.
    """

    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 **kwargs):
        attention_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        ))
        feed_forward_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        ))
        super().__init__(attention_sublayer, feed_forward_sublayer)

In [None]:
# Sanity check: an encoder block maps (batch, tokens, emb_size) to the same shape.
patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape

In [None]:
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlocks, all built with the same ``kwargs``."""

    def __init__(self, depth: int = 12, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(**kwargs))
        super().__init__(*blocks)

In [None]:
class ClassificationHead(nn.Sequential):
    """Average the token axis, apply LayerNorm, and project to ``n_classes`` logits.

    Inside ViT the average also includes the cls token prepended by PatchEmbedding.
    """

    def __init__(self, emb_size: int = 768, n_classes: int = 5):
        token_mean = Reduce('b n e -> b e', reduction='mean')
        norm = nn.LayerNorm(emb_size)
        classifier = nn.Linear(emb_size, n_classes)
        super().__init__(token_mean, norm, classifier)

In [None]:
class ViT(nn.Sequential):
    """Vision Transformer: patch embedding -> ``depth`` encoder blocks -> classification head.

    Extra ``kwargs`` are forwarded to every TransformerEncoderBlock (and from there
    to MultiHeadAttention, e.g. ``num_heads``).
    """

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 5,
                 **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)

In [None]:
model = ViT().to(device)
# Smoke test: one forward pass on the sample image. no_grad() avoids building
# an autograd graph that is never used.
with torch.no_grad():
    print(model(x.to(device)))
criterion = nn.CrossEntropyLoss()
# Adam with lr=0.1 is orders of magnitude above the usual range for training
# transformers and tends to make the loss blow up; 1e-4 is a conventional start.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
batch_size=48
def train_one_epoch(epoch_index):
    """Run one training pass of ``model`` over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. Prints the mean loss of every 20 batches.

    Args:
        epoch_index: zero-based epoch number (currently unused; kept for callers).

    Returns:
        Training accuracy for the epoch as a Python float (correct / samples seen).
    """
    # Make sure dropout etc. are active even if an earlier cell called model.eval().
    model.train()
    running_loss = 0.
    correct = 0
    seen = 0

    for i, (data, target) in enumerate(train_loader):
        inputs, labels = data.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
        running_loss += loss.item()
        if (i + 1) % 20 == 0:
            last_loss = running_loss / 20  # mean loss per batch over the last 20 batches
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    # Divide by the samples actually seen: the final batch may be smaller than batch_size.
    return correct / seen if seen else 0.0

In [None]:
# Train for 10 epochs, printing each epoch's training accuracy.
for epoch in range(10):
    print(f'Epoch: {epoch + 1}')
    out = train_one_epoch(epoch)
    print("Final:", out)

In [None]:
def valid_check():
    """Evaluate ``model`` on ``valid_loader`` and return the accuracy as a float.

    Runs in eval mode under ``torch.no_grad()`` so no autograd graph is built, and
    restores the model's previous train/eval mode afterwards. Prints the running
    accuracy every 10 batches.
    """
    was_training = model.training
    model.eval()
    correct = 0
    seen = 0
    try:
        with torch.no_grad():
            for i, (data, target) in enumerate(valid_loader):
                inputs, labels = data.to(device), target.to(device)
                outputs = model(inputs)
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                # Count real samples: the last batch may be smaller than batch_size.
                seen += labels.size(0)
                if i % 10 == 0:
                    print('  batch {} correct: {}'.format(i + 1, correct / seen))
    finally:
        model.train(was_training)

    return correct / seen if seen else 0.0

In [None]:
# Evaluate on the validation split and print the overall accuracy.
print("Final Accuracy:", valid_check())

In [None]:
# Model saving: bundle the epoch counter, model/optimizer state and loss into one file.
# NOTE(review): the training loop above runs 10 epochs, but epoch is recorded as 5;
# confirm which value is intended.
# NOTE(review): 'loss' stores the CrossEntropyLoss module itself (pickled), not a loss
# value, so loading this file needs torch.load(..., weights_only=False).
epoch = 5
PATH = r"C:\Users\MOJAHID HUSSAIN\Desktop\vit\model_checkpoint.pth"
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': criterion,
}
torch.save(checkpoint, PATH)

In [None]:

# Rebuild the model on the training device and restore the saved weights.
model = ViT().to(device)
# Re-create the optimizer for the NEW model's parameters. The optimizer from the
# training cells still points at the old model's tensors, so loading state into it
# would leave this model untrainable. load_state_dict below restores the saved
# hyper-parameters (lr, betas, ...) from param_groups.
optimizer = torch.optim.Adam(model.parameters())

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# weights_only=False is needed because the checkpoint pickles the loss module under
# 'loss'; unpickling can run arbitrary code, so only do this for checkpoints you
# created yourself. (The weights_only argument requires torch >= 1.13.)
checkpoint = torch.load(PATH, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
criterion = checkpoint['loss']

model.eval()