In [17]:
import numpy as np
import torch
import random
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchmetrics.classification import Accuracy
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,Subset,random_split
from vit_pytorch import ViT
from vit_pytorch.vit import Transformer, Attention

In [18]:
data_path = "filtered_species"
batch_size = 64

transform = transforms.Compose([
    transforms.CenterCrop((256, 256)),  #Resize to minimum of all sizes - Will update size in cnn architecture
    transforms.RandomHorizontalFlip(p = 0.25),
    transforms.RandomRotation(degrees = 30),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

data = datasets.ImageFolder(root=data_path, transform=transform)


subset_size = 5000
indices = random.sample(range(len(data)), subset_size)
data_subset = Subset(data, indices)

#change data_subset <-> data for train from whole data
train_size = int(0.8 * len(data))
test_size = len(data) - train_size
training, testing = random_split(data, [train_size, test_size], generator=torch.Generator().manual_seed(1111))

training_dataset = DataLoader(training, batch_size=batch_size, shuffle=True, num_workers=6, pin_memory=True)
testing_dataset = DataLoader(testing, batch_size=batch_size, shuffle=True, num_workers=6, pin_memory=True)
#For testing purposes
#print("Class names:", data.classes)
#len(data.classes)

In [19]:
class BirdClassifier(nn.Module):
    def __init__(self,
                 cnn_state=False,                # Whether to use CNN before ViT
                 image_size=256,
                 patch_size=16,
                 num_class=10,
                 dim=256,
                 layer_count=1,
                 head_count=1,
                 transformer_ff_neurons=256,
                 transformer_dropout=0.2):
        super().__init__()
        
        self.cnn_state = cnn_state
        self.image_size = image_size if not cnn_state else 64  # Will update if CNN is used

        if cnn_state:
            self.cnn = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # AlexNet
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),
                
                nn.Conv2d(96, 256, kernel_size=5, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),

                nn.Conv2d(256, 384, kernel_size=3, padding=1),
                nn.ReLU(),

                nn.Conv2d(384, 384, kernel_size=3, padding=1),
                nn.ReLU(),

                nn.Conv2d(384, 256, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((self.image_size, self.image_size))  # Resize to ViT input
            )

        self.vision_transformer = ViT(
            image_size=self.image_size,
            patch_size=patch_size,
            num_classes=num_class,
            dim=dim,
            depth=layer_count,
            heads=head_count,
            mlp_dim=transformer_ff_neurons,
            dropout=transformer_dropout,
            emb_dropout=transformer_dropout,
            channels=3
        )

    def forward(self, x):
        if self.cnn_state:
            x = self.cnn(x)
        x = self.vision_transformer(x)
        return x

    def print_config(self):
        print(f"Using CNN: {self.cnn_state}")
        print(f"ViT dim: {self.vision_transformer.dim}, layers: {self.vision_transformer.depth}")




In [20]:
def train_model(model, dataloader, criterion, optimizer_metric, accuracy_metric, device):
    model.train()
    net_loss = 0
    for images, labels in tqdm(dataloader, desc = "TRAINIGN"):
        images, labels = images.to(device), labels.to(device)
        y_hat = model(images)
        loss = criterion(y_hat, labels)
        optimizer_metric.zero_grad()
        loss.backward()
        optimizer_metric.step()
        accuracy_metric.update(y_hat, labels)
        net_loss += loss.item()
    
    epoch_accuracy = accuracy_metric.compute().item()
    epoch_loss = net_loss/(len(dataloader))

    return epoch_accuracy, epoch_loss

def test_model(model, dataloader, criterion, accuracy_metric, device):
    model.eval()
    net_loss = 0
    for images, labels in tqdm(dataloader, desc = "TESTING"):
        image, label = images.to(device), labels.to(device)
        y_hat = model(image)
        loss = criterion(y_hat, label)
        accuracy_metric.update(y_hat, label)
        net_loss += loss.item()
    
    epoch_accuracy = accuracy_metric.compute().item()
    epoch_loss = net_loss/(len(dataloader))

    return epoch_accuracy, epoch_loss

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_multihead_8 = BirdClassifier(cnn_state=False, layer_count=4, head_count = 8).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model_multihead_8.parameters(), lr=3e-4)
accuracy_score = Accuracy(task = 'multiclass', num_classes = 10).to(device)


patience = 3
min_delta = 3e-3
best_accuracy = 0
counter = 0

loss_scores_train_multihead_8 = []
accuracy_scores_train_multihead_8 = []

loss_scores_test_multihead_8 = []
accuracy_scores_test_multihead_8 = []

for epoch in range(200):  
    print(f"\nEpoch {epoch+1}")
    train_acc, train_loss = train_model(model_multihead_8, training_dataset, criterion, optimizer, accuracy_score, device)
    test_acc, test_loss = test_model(model_multihead_8, testing_dataset, criterion, accuracy_score, device)
    
    accuracy_scores_train_multihead_8.append(train_acc)
    loss_scores_train_multihead_8.append(train_loss)

    accuracy_scores_test_multihead_8.append(test_acc)
    loss_scores_test_multihead_8.append(test_loss)
    print(f"Train Acc: {train_acc:.4f} | Train Loss: {train_loss:.4f} | Test Acc: {test_acc:.4f} | Test Loss: {test_loss}")
    
    if test_acc - best_test_acc > min_delta:
        best_test_acc = test_acc
        counter = 0
    else:
        counter += 1
        print(f"No improvement. Early stopping counter: {counter}/{patience}")
        if counter >= patience:
            print("Early stopping triggered.")
            break


Epoch 1


TRAINIGN: 100%|██████████| 587/587 [00:33<00:00, 17.58it/s]
TESTING: 100%|██████████| 147/147 [00:07<00:00, 20.88it/s]


Train Acc: 0.3142 | Train Loss: 1.9243 | Test Acc: 0.3266 | Test Loss: 1.7574983113477018

Epoch 2


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.11it/s]
TESTING: 100%|██████████| 147/147 [00:06<00:00, 21.22it/s]


Train Acc: 0.3478 | Train Loss: 1.7610 | Test Acc: 0.3519 | Test Loss: 1.8013007884122887

Epoch 3


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.08it/s]
TESTING: 100%|██████████| 147/147 [00:06<00:00, 21.96it/s]


Train Acc: 0.3670 | Train Loss: 1.6857 | Test Acc: 0.3712 | Test Loss: 1.6218019140009978

Epoch 4


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 17.97it/s]
TESTING: 100%|██████████| 147/147 [00:06<00:00, 22.50it/s]


Train Acc: 0.3824 | Train Loss: 1.6397 | Test Acc: 0.3857 | Test Loss: 1.5892163588076222

Epoch 5


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.32it/s]
TESTING: 100%|██████████| 147/147 [00:06<00:00, 22.44it/s]


Train Acc: 0.3950 | Train Loss: 1.5990 | Test Acc: 0.3974 | Test Loss: 1.567760478071615

Epoch 6


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.23it/s]
TESTING: 100%|██████████| 147/147 [00:07<00:00, 19.07it/s]


Train Acc: 0.4046 | Train Loss: 1.5653 | Test Acc: 0.4063 | Test Loss: 1.5496198247079136

Epoch 7


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.06it/s]
TESTING: 100%|██████████| 147/147 [00:06<00:00, 21.12it/s]


Train Acc: 0.4132 | Train Loss: 1.5354 | Test Acc: 0.4151 | Test Loss: 1.5100424265374943

Epoch 8


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.28it/s]
TESTING: 100%|██████████| 147/147 [00:07<00:00, 20.48it/s]


Train Acc: 0.4209 | Train Loss: 1.5116 | Test Acc: 0.4219 | Test Loss: 1.5585161050160725

Epoch 9


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.11it/s]
TESTING: 100%|██████████| 147/147 [00:07<00:00, 20.75it/s]


Train Acc: 0.4273 | Train Loss: 1.4848 | Test Acc: 0.4286 | Test Loss: 1.466239907303635

Epoch 10


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.10it/s]
TESTING: 100%|██████████| 147/147 [00:06<00:00, 21.86it/s]


Train Acc: 0.4337 | Train Loss: 1.4585 | Test Acc: 0.4349 | Test Loss: 1.4541337652271296

Epoch 11


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.24it/s]
TESTING: 100%|██████████| 147/147 [00:07<00:00, 20.99it/s]


Train Acc: 0.4393 | Train Loss: 1.4460 | Test Acc: 0.4405 | Test Loss: 1.4191117416433736

Epoch 12


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.22it/s]
TESTING: 100%|██████████| 147/147 [00:07<00:00, 20.55it/s]


Train Acc: 0.4447 | Train Loss: 1.4230 | Test Acc: 0.4458 | Test Loss: 1.4387104989720039

Epoch 13


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.15it/s]
TESTING: 100%|██████████| 147/147 [00:07<00:00, 20.77it/s]


Train Acc: 0.4498 | Train Loss: 1.4044 | Test Acc: 0.4507 | Test Loss: 1.4480376592298754

Epoch 14


TRAINIGN: 100%|██████████| 587/587 [00:32<00:00, 18.00it/s]
TESTING: 100%|██████████| 147/147 [00:06<00:00, 23.02it/s]


Train Acc: 0.4543 | Train Loss: 1.3946 | Test Acc: 0.4553 | Test Loss: 1.3918831299762338

Epoch 15


TRAINIGN: 100%|██████████| 587/587 [00:31<00:00, 18.37it/s]
TESTING: 100%|██████████| 147/147 [00:05<00:00, 24.84it/s]


Train Acc: 0.4589 | Train Loss: 1.3721 | Test Acc: 0.4598 | Test Loss: 1.3709381028097503

Epoch 16


TRAINIGN: 100%|██████████| 587/587 [00:31<00:00, 18.40it/s]
TESTING: 100%|██████████| 147/147 [00:06<00:00, 22.10it/s]


Train Acc: 0.4633 | Train Loss: 1.3543 | Test Acc: 0.4642 | Test Loss: 1.375197401663073

Epoch 17


TRAINIGN:  26%|██▌       | 154/587 [00:09<00:25, 16.97it/s]


KeyboardInterrupt: 