In [17]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
import numpy as np
import sklearn.metrics as metrics

## Hyperparameters and Model Architecture

In [59]:
WEIGHT_DECAY = 0.0
LEARNING_RATE = 0.001
BATCH_SIZE = 10
# CRITERION = F.nll_loss
CRITERION = nn.CrossEntropyLoss()
EPOCHS = 3
NUM_BLOCKS = 2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [51]:
class MSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads
        d_head = int(d / n_heads)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]
                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [52]:
class ViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(ViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out

In [53]:
class ViT(nn.Module):
    def __init__(self, dims, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        super(ViT, self).__init__()
        self.dims = dims # ( C , H , W )
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d
        self.patch_size = (dims[1] / n_patches, dims[2] / n_patches)
        self.input_d = int(dims[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))
        self.register_buffer('positional_embeddings', self.get_positional_embeddings(n_patches ** 2 + 1, hidden_d), persistent=False)
        self.blocks = nn.ModuleList([ViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )
        
    def patch(self, images, n_patches):
        n, c, h, w = images.shape
        patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
        patch_size = h // n_patches
        for idx, image in enumerate(images):
            for i in range(n_patches):
                for j in range(n_patches):
                    patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                    patches[idx, i * n_patches + j] = patch.flatten()
        return patches
    
    def get_positional_embeddings(self, sequence_length, d):
        result = torch.ones(sequence_length, d)
        for i in range(sequence_length):
            for j in range(d):
                result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
        return result

    def forward(self, images):
        n, c, h, w = images.shape
        patches = self.patch(images, self.n_patches).to(self.positional_embeddings.device)
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)
        for block in self.blocks:
            out = block(out)
        out = out[:, 0]
        return self.mlp(out)

### MNIST

In [57]:
train = datasets.MNIST("./347data", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test = datasets.MNIST("./347data", train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))
validation_set_size = int(len(train) * 0.1)
training_set_size = len(train) - validation_set_size
train_set, validation_set = torch.utils.data.random_split(train, [training_set_size, validation_set_size])
train_set = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
validation_set = torch.utils.data.DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=True)
test_set = torch.utils.data.DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=True)

In [60]:
MNIST_vit = ViT((1, 28, 28), n_patches=7, n_blocks=NUM_BLOCKS, hidden_d=8, n_heads=2, out_d=10).to(device)
optimizer = optim.Adam(MNIST_vit.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
for epoch in range(EPOCHS):
    for data in tqdm(train_set):
        X, y = data
        MNIST_vit.zero_grad()
        output = MNIST_vit(X.to(device))
        loss = CRITERION(output, y.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch + 1} Loss: {loss.item()}")

100%|██████████| 5400/5400 [06:18<00:00, 14.27it/s]
  0%|          | 2/5400 [00:00<06:11, 14.55it/s]

Epoch 1 Loss: 1.7676780223846436


100%|██████████| 5400/5400 [04:54<00:00, 18.33it/s]
  0%|          | 2/5400 [00:00<05:44, 15.65it/s]

Epoch 2 Loss: 1.853771448135376


100%|██████████| 5400/5400 [04:57<00:00, 18.14it/s]

Epoch 3 Loss: 1.7965999841690063





In [61]:
output = []
true = []
MNIST_vit.eval()
with torch.no_grad():
    for data in validation_set:
        X, y = data
        for i in MNIST_vit(X.to(device)):
            output.append(torch.argmax(i).cpu())
        for i in y:
            true.append(i)
MNIST_vit.train()
print("Validation Accuracy:", metrics.accuracy_score(true, output))
print("Validation F1 Score:", metrics.f1_score(true, output, average="macro"))
true = np.eye(10)[true]
output = np.eye(10)[output]
print("Validation AUC Score:", metrics.roc_auc_score(true, output, multi_class="ovo", average="macro"))

Validation Accuracy: 0.7481666666666666
Validation F1 Score: 0.7420055553043702
Validation AUC Score: 0.8602653796811438


### CIFAR-10

In [62]:
train = datasets.CIFAR10("./347data", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test = datasets.CIFAR10("./347data", train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))
validation_set_size = int(len(train) * 0.1)
training_set_size = len(train) - validation_set_size
train_set, validation_set = torch.utils.data.random_split(train, [training_set_size, validation_set_size])
train_set = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
validation_set = torch.utils.data.DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=True)
test_set = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [64]:
CIFAR10_vit = ViT((3, 32, 32), n_patches=7, n_blocks=NUM_BLOCKS, hidden_d=8, n_heads=2, out_d=10).to(device)
optimizer = optim.Adam(MNIST_vit.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
for epoch in range(EPOCHS):
    for data in tqdm(train_set):
        X, y = data
        CIFAR10_vit.zero_grad()
        output = CIFAR10_vit(X.to(device))
        loss = CRITERION(output, y.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch + 1} Loss: {loss.item()}")

  0%|          | 0/5000 [00:00<?, ?it/s]


RuntimeError: The expanded size of the tensor (62) must match the existing size (48) at non-singleton dimension 0.  Target sizes: [62].  Tensor sizes: [48]

In [None]:
output = []
true = []
CIFAR10_vit.eval()
with torch.no_grad():
    for data in validation_set:
        X, y = data
        for i in CIFAR10_vit(X.to(device)):
            output.append(torch.argmax(i).cpu())
        for i in y:
            true.append(i)
CIFAR10_vit.train()
print("Validation Accuracy:", metrics.accuracy_score(true, output))
print("Validation F1 Score:", metrics.f1_score(true, output, average="macro"))
true = np.eye(10)[true]
output = np.eye(10)[output]
print("Validation AUC Score:", metrics.roc_auc_score(true, output, multi_class="ovo", average="macro"))
            

### Iyer

In [None]:
iyer = np.loadtxt(open("347data/iyer.txt", "rb"), delimiter="\t")
features = iyer[:, 2:].astype(float)
labels = iyer[:, 1].astype(int)

data = [] #data0
for i in range(features.shape[0]):
    stack = np.array([])
    stack = np.column_stack([np.roll(features[i,], j, axis=0) for j in range(features.shape[1])]).astype(np.float32)
    data.append([stack, np.array(labels[i] + 1, dtype=int)])

data = np.array(data)
np.random.shuffle(data)
X = torch.tensor(np.array([i[0] for i in data])).view(-1, 1, 12, 12)
y = torch.tensor(np.array([i[1] for i in data]))

test_set_size = int(X.shape[0] * 0.1)
training_set_size = X.shape[0] - test_set_size
validation_set_size = int(training_set_size * 0.1)
training_set_size -= validation_set_size
print(training_set_size, validation_set_size, test_set_size)
train_X = X[:training_set_size]
train_y = y[:training_set_size]
validation_X = X[training_set_size:training_set_size + validation_set_size]
validation_y = y[training_set_size:training_set_size + validation_set_size]
test_X = X[training_set_size + validation_set_size:]
test_y = y[training_set_size + validation_set_size:]