In [2]:
from typing import Optional, Tuple, Union
from zipfile import ZipFile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy.typing import NDArray
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from tqdm import tqdm, trange

# Single source of truth for the random seed; both NumPy and PyTorch use it.
RSEED = 0
np.random.seed(RSEED)
torch.manual_seed(RSEED)

# Mini-batch size shared by the data loaders.
BATCH = 32

In [3]:
# Unpack the dataset archive into the current working directory.
with ZipFile("Colorectal_Histology.zip") as archive:
    archive.extractall(".")

In [9]:
# Load the table, then separate the pixel columns (features) from the label column.
df = pd.read_csv("hmnist_28_28_L.csv")
X = df.drop(columns="label").to_numpy()
y = df["label"].to_numpy()


In [17]:
# Preview the first ten rows of the loaded table.
df.head(10)

Unnamed: 0,pixel0000,pixel0001,pixel0002,pixel0003,pixel0004,pixel0005,pixel0006,pixel0007,pixel0008,pixel0009,...,pixel0775,pixel0776,pixel0777,pixel0778,pixel0779,pixel0780,pixel0781,pixel0782,pixel0783,label
0,101,110,154,160,95,44,139,184,164,160,...,103,73,72,75,152,130,96,133,159,2
1,67,66,69,76,80,57,46,67,90,77,...,58,65,74,80,81,83,77,75,73,2
2,127,137,121,140,170,111,128,117,60,105,...,90,100,143,119,148,140,193,146,97,2
3,80,90,101,106,120,100,99,66,63,91,...,131,109,97,102,71,93,120,84,62,2
4,153,141,121,132,110,131,119,99,101,91,...,117,121,136,178,192,210,189,149,155,2
5,120,102,83,73,74,77,86,88,89,93,...,128,103,104,106,128,152,141,133,107,2
6,187,187,189,199,198,179,177,174,179,166,...,140,148,142,129,132,146,125,160,207,2
7,73,98,126,153,116,45,89,124,108,89,...,98,101,125,147,115,156,170,114,90,2
8,124,132,130,129,141,146,125,117,114,107,...,126,120,122,127,119,116,143,157,138,2
9,141,176,171,173,205,229,236,238,231,239,...,158,158,208,220,209,171,131,129,131,2


In [10]:
# Check which container type holds the labels.
type(y)

numpy.ndarray

In [11]:
# Display the feature matrix (NumPy abbreviates large arrays in its repr).
X

array([[101, 110, 154, ...,  96, 133, 159],
       [ 67,  66,  69, ...,  77,  75,  73],
       [127, 137, 121, ..., 193, 146,  97],
       ...,
       [ 27,  50,  94, ..., 223, 149,  77],
       [108, 113, 116, ..., 132,  93,  83],
       [ 67,  74,  67, ..., 121,  92,  77]], shape=(5000, 784))

In [18]:
class histo(Dataset):
    """In-memory dataset of flattened 28x28 single-channel images.

    Every row of ``X`` is reshaped to ``(1, 28, 28)`` and rescaled from the
    0-255 integer range to ``float32`` values in ``[0, 1]``.  Labels are
    optional so the same class can wrap unlabeled data.
    """

    def __init__(self, X: NDArray[np.integer], y: Optional[NDArray[np.integer]]) -> None:
        """Store the images as a float tensor and keep the labels as given.

        Args:
            X: Array whose rows each hold 784 pixel values.  Values are cast
                to ``uint8``, so they are expected to lie in 0..255.
            y: One label per row of ``X``, or ``None`` for unlabeled data.

        Raises:
            ValueError: If ``y`` is given but its length differs from the
                number of images.
        """
        pixels = torch.tensor(X, dtype=torch.uint8).view(-1, 1, 28, 28)
        self.X = pixels.float() / 255

        # Catch misaligned inputs here rather than as an IndexError (or a
        # silent image/label mismatch) deep inside a training loop.
        if y is not None and len(y) != len(self.X):
            raise ValueError(
                f"X has {len(self.X)} samples but y has {len(y)} labels"
            )
        self.y = y

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.X)

    def __getitem__(self, idx: int) -> Union[torch.Tensor, Tuple[torch.Tensor, int]]:
        """Return ``(image, label)``, or only the image when there are no labels."""
        if self.y is None:
            return self.X[idx]
        return self.X[idx], self.y[idx]

    @staticmethod
    def create_split(
        X,
        y,
        train_fraction: float,
        val_fraction: float,
        test_fraction: float,
        random_state: Optional[int] = None,
    ) -> Tuple[Dataset, Dataset]:
        """Randomly split ``X``/``y`` into a training and a validation dataset.

        Args:
            X: Feature rows, as accepted by ``histo``.
            y: Labels aligned with ``X``.
            train_fraction: Share of the samples used for training.
            val_fraction: Share of the samples used for validation.
            test_fraction: Share of the samples reserved for testing.  It is
                only validated here; no test dataset is built or returned.
            random_state: Optional seed forwarded to ``train_test_split`` for
                a reproducible split.  ``None`` keeps the previous behaviour
                of drawing from NumPy's global random state.

        Returns:
            ``(train_data, val_data)``.  When ``train_fraction`` and
            ``val_fraction`` add up to less than 1, the remaining rows are
            left out of both datasets.
        """
        total = train_fraction + val_fraction + test_fraction
        # Small tolerance so sums such as 0.1 + 0.2 + 0.7 are not rejected
        # because of floating-point rounding.
        assert total <= 1 + 1e-9, f"fractions sum to {total}, expected at most 1"

        # When train + val cover the whole data set, let scikit-learn use the
        # complement of the validation share (identical to the old behaviour).
        # Otherwise honour the requested training share explicitly.
        covers_everything = train_fraction + val_fraction >= 1 - 1e-9
        X_train, X_val, y_train, y_val = train_test_split(
            X,
            y,
            train_size=None if covers_everything else train_fraction,
            test_size=val_fraction,
            random_state=random_state,
        )

        return histo(X_train, y_train), histo(X_val, y_val)


In [19]:
# Half of the samples for training, half for validation, none reserved for testing.
train_data, val_data = histo.create_split(
    X, y, train_fraction=0.5, val_fraction=0.5, test_fraction=0.0
)

In [20]:
# Inspect a single sample from the training split.
train_data[25]

(tensor([[[0.3569, 0.3137, 0.2549, 0.2078, 0.0902, 0.0941, 0.0824, 0.0706,
           0.2471, 0.3412, 0.2118, 0.1686, 0.1765, 0.1882, 0.2275, 0.2471,
           0.2118, 0.2235, 0.2392, 0.2392, 0.1686, 0.1843, 0.2039, 0.1961,
           0.3020, 0.2667, 0.2196, 0.3373],
          [0.3137, 0.2863, 0.2627, 0.2471, 0.0863, 0.1020, 0.1098, 0.0863,
           0.3059, 0.4314, 0.2157, 0.1725, 0.1059, 0.2078, 0.2667, 0.2510,
           0.2784, 0.2706, 0.2235, 0.2275, 0.1412, 0.1882, 0.3451, 0.3294,
           0.3176, 0.2784, 0.2706, 0.2510],
          [0.3216, 0.2667, 0.3529, 0.1922, 0.0980, 0.0863, 0.0824, 0.0941,
           0.2510, 0.2824, 0.2000, 0.2157, 0.2275, 0.2667, 0.2471, 0.1529,
           0.1843, 0.2667, 0.2902, 0.2588, 0.2667, 0.4353, 0.4353, 0.3608,
           0.2941, 0.3020, 0.2980, 0.3176],
          [0.2745, 0.2863, 0.2627, 0.1373, 0.1216, 0.0784, 0.0392, 0.2745,
           0.4667, 0.2157, 0.1725, 0.1529, 0.1608, 0.1922, 0.2157, 0.2784,
           0.3647, 0.3333, 0.3529, 0.2980, 

In [21]:
def main(num_epochs: int = 4, learn_rate: float = 0.005) -> None:
    """Train a ``ViTModel`` and report validation metrics after every epoch.

    Reads the notebook-level ``train_data``, ``val_data`` and ``BATCH``.

    Args:
        num_epochs: Number of passes over the training data.
        learn_rate: Learning rate handed to the Adam optimizer.
    """
    train_loader = DataLoader(train_data, batch_size=BATCH, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=BATCH, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): 28 is not a multiple of 16 - confirm how ViTModel
    # interprets n_patches (patches per side vs. total number of patches).
    model = ViTModel((1, 28, 28), n_patches=16, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)

    optimizer = Adam(model.parameters(), lr=learn_rate)
    loss_fn = CrossEntropyLoss()

    for epoch in trange(num_epochs, desc="Training"):
        # Put layers such as dropout / batch-norm back into training mode;
        # the validation pass below switches them to eval mode every epoch.
        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch +1} in training", leave=False):
            images, labels = images.to(device), labels.to(device)
            loss = loss_fn(model(images), labels)

            # Running mean of the per-batch losses.
            train_loss += loss.item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs} loss: {train_loss:.3f}")

        # Evaluation must not use training-time layer behaviour or track gradients.
        model.eval()
        with torch.no_grad():
            correct, total = 0, 0
            test_loss = 0.0

            for images, labels in tqdm(val_loader, desc="Testing"):
                images, labels = images.to(device), labels.to(device)

                logits = model(images)
                test_loss += loss_fn(logits, labels).item() / len(val_loader)

                correct += (logits.argmax(dim=1) == labels).sum().item()
                total += len(images)

            # Guard against an empty validation set instead of dividing by zero.
            accuracy = correct / total * 100 if total else float("nan")
            print(f"Test loss: {test_loss:.3f}")
            print(f"Test accuracy: {accuracy:.3f}")



    

In [24]:
class ViTModel(nn.Module):
    """Skeleton of the Vision Transformer classifier used by ``main``.

    Only the constructor arguments are recorded so far; the forward pass is
    still to be written and raises ``NotImplementedError`` when called.
    """

    def __init__(self, chw=(1, 28, 28), n_patches=16, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        """Store the model hyper-parameters.

        The defaults mirror the values ``main`` passes in this notebook, so
        ``ViTModel()`` keeps working as before.

        Args:
            chw: ``(channels, height, width)`` of the input images.
            n_patches: Patch-count setting (exact meaning to be fixed by the
                forward implementation).
            n_blocks: Number of transformer blocks.
            hidden_d: Hidden (embedding) dimension.
            n_heads: Number of attention heads.
            out_d: Size of the output layer.
        """
        super().__init__()
        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        self.out_d = out_d

    def forward(self, images):
        """Compute class scores for ``images`` (not implemented yet).

        Raises:
            NotImplementedError: Always, until the model body is written.
                Failing loudly here avoids returning ``None`` to the loss
                function.
        """
        raise NotImplementedError("ViTModel.forward is not implemented yet")

In [25]:
def make_patches(images, num_patches):
    """Cut square images into a grid of flattened, non-overlapping patches.

    Args:
        images: Tensor of shape ``(n, c, h, w)`` with ``h == w``.
        num_patches: Number of patches along each side, so every image yields
            ``num_patches ** 2`` patches.

    Returns:
        Tensor of shape ``(n, num_patches**2, c * patch_size**2)`` on the same
        device and with the same dtype as ``images``.  Patches are ordered row
        by row over the grid, and each patch is flattened in
        ``(channel, row, column)`` order.

    Raises:
        AssertionError: If the images are not square.
        ValueError: If the image side is not divisible by ``num_patches``.
    """
    n, c, h, w = images.shape

    assert h == w, "make_patches only supports square images"
    if num_patches <= 0 or h % num_patches != 0:
        raise ValueError(
            f"image side {h} cannot be split evenly into {num_patches} patches"
        )

    patch_size = h // num_patches

    # Split both spatial axes into (grid index, offset inside the patch) ...
    grid = images.reshape(n, c, num_patches, patch_size, num_patches, patch_size)
    # ... move the two grid indices to the front: (n, row, col, c, ph, pw) ...
    grid = grid.permute(0, 2, 4, 1, 3, 5)
    # ... and flatten every patch into a single vector.
    return grid.reshape(n, num_patches ** 2, c * patch_size * patch_size)