In [7]:
import numpy as np
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torchmetrics.classification import Accuracy
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,Subset,random_split
from vit_pytorch import ViT
from vit_pytorch.vit import Transformer, Attention

### Data Loading

In [8]:
data_path = "filtered_species"
batch_size = 64
SEED = 1111
USE_SUBSET = False  # True: train/test on a random sample of `subset_size` images for quick experiments
subset_size = 1000

# Seed the RNGs this pipeline touches: `random` (subset sampling) and torch
# (augmentations, DataLoader shuffling, later weight initialisation).
random.seed(SEED)
torch.manual_seed(SEED)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training pipeline: centre-crop to 256x256 (a crop, not a resize), then random
# flip/rotation augmentation.
transform = transforms.Compose([
    transforms.CenterCrop((256, 256)),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomRotation(degrees=30),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: same crop and normalisation, but deterministic (no augmentation),
# so test metrics are not perturbed by random flips/rotations.
eval_transform = transforms.Compose([
    transforms.CenterCrop((256, 256)),
    transforms.ToTensor(),
    normalize,
])

# Two views of the same folder so the train and test splits can use different transforms.
data = datasets.ImageFolder(root=data_path, transform=transform)
eval_data = datasets.ImageFolder(root=data_path, transform=eval_transform)

if USE_SUBSET:
    indices = random.sample(range(len(data)), min(subset_size, len(data)))
else:
    indices = list(range(len(data)))

train_size = int(0.8 * len(indices))
test_size = len(indices) - train_size

# Split positions within `indices`, then map them back to dataset indices so each split
# can be attached to the matching (augmented vs. deterministic) ImageFolder. With
# USE_SUBSET=False this reproduces the same partition as random_split(data, ...).
train_split, test_split = random_split(indices, [train_size, test_size],
                                       generator=torch.Generator().manual_seed(SEED))
training = Subset(data, [indices[i] for i in train_split.indices])
testing = Subset(eval_data, [indices[i] for i in test_split.indices])
assert len(training) + len(testing) == len(indices)

training_dataset = DataLoader(training, batch_size=batch_size, shuffle=True)
# Evaluation order does not affect the metrics, so the test loader is not shuffled.
testing_dataset = DataLoader(testing, batch_size=batch_size, shuffle=False)

### Classifier Architecture

In [9]:
class BirdClassifier(nn.Module):
    """Image classifier built on vit_pytorch's ``ViT``, optionally preceded by an
    AlexNet-style convolutional stem.

    Args:
        cnn_state: If True, images go through the CNN stem first and the ViT consumes
            its ``CNN_OUT_CHANNELS``-channel feature map, pooled to
            ``cnn_output_size`` x ``cnn_output_size``.
        image_size: Spatial size of the ViT input when no CNN stem is used.
        patch_size: ViT patch size; must divide the ViT input size.
        num_class: Number of output classes.
        dim: Transformer embedding width.
        layer_count: Number of transformer blocks (ViT ``depth``).
        head_count: Number of attention heads.
        transformer_ff_neurons: Hidden width of the transformer MLP (ViT ``mlp_dim``).
        transformer_dropout: Dropout used both inside the transformer and on embeddings.
        cnn_output_size: Spatial size the CNN feature map is pooled to (only used when
            ``cnn_state`` is True).
    """

    # Channel count of the CNN stem's output (the last Conv2d below).
    CNN_OUT_CHANNELS = 256

    def __init__(self,
                 cnn_state=False,                # Whether to use CNN before ViT
                 image_size=256,
                 patch_size=16,
                 num_class=22,
                 dim=128,
                 layer_count=1,
                 head_count=1,
                 transformer_ff_neurons=256,
                 transformer_dropout=0.1,
                 cnn_output_size=64):
        super().__init__()

        self.cnn_state = cnn_state
        # With the CNN stem the ViT sees the pooled feature map, not the raw image.
        self.image_size = cnn_output_size if cnn_state else image_size
        # Kept so print_config can report them; ViT does not store these as attributes.
        self.dim = dim
        self.layer_count = layer_count

        if cnn_state:
            self.cnn = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # AlexNet-style
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),

                nn.Conv2d(96, 256, kernel_size=5, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),

                nn.Conv2d(256, 384, kernel_size=3, padding=1),
                nn.ReLU(),

                nn.Conv2d(384, 384, kernel_size=3, padding=1),
                nn.ReLU(),

                nn.Conv2d(384, self.CNN_OUT_CHANNELS, kernel_size=3, padding=1),
                nn.ReLU(),
                # Fix the spatial size the ViT expects. NOTE(review): for 256x256 inputs the
                # map reaching this layer is 15x15, so pooling to 64x64 upsamples it.
                nn.AdaptiveAvgPool2d((self.image_size, self.image_size))
            )
            # The ViT's patch embedding must match the CNN's channel count, not RGB.
            vit_channels = self.CNN_OUT_CHANNELS
        else:
            vit_channels = 3

        self.vision_transformer = ViT(
            image_size=self.image_size,
            patch_size=patch_size,
            num_classes=num_class,
            dim=dim,
            depth=layer_count,
            heads=head_count,
            mlp_dim=transformer_ff_neurons,
            dropout=transformer_dropout,
            emb_dropout=transformer_dropout,
            channels=vit_channels
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x`` (optionally via the CNN stem)."""
        if self.cnn_state:
            x = self.cnn(x)
        return self.vision_transformer(x)

    def print_config(self):
        """Print whether the CNN stem is enabled and the ViT width/depth."""
        print(f"Using CNN: {self.cnn_state}")
        print(f"ViT dim: {self.dim}, layers: {self.layer_count}")




In [10]:
def train_model(model, dataloader, criterion, optimizer_metric, accuracy_metric, device):
    """Run one training epoch and return ``(accuracy, mean_loss)`` for that epoch.

    Args:
        model: Network to train; switched to train mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer_metric: Optimizer stepped once per batch.
        accuracy_metric: Metric exposing ``reset``/``update``/``compute`` (e.g. a
            torchmetrics ``Accuracy``). It is reset first so the result covers this
            epoch only.
        device: Device each batch is moved to.

    Returns:
        ``(epoch_accuracy, epoch_loss)``; ``epoch_loss`` is the per-sample mean loss.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    # Without a reset the metric keeps accumulating across calls, so later epochs would
    # report a running accuracy over every epoch seen so far.
    accuracy_metric.reset()
    loss_sum = 0.0
    sample_count = 0

    for images, labels in tqdm(dataloader, desc="TRAINING"):
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)

        optimizer_metric.zero_grad()
        loss.backward()
        optimizer_metric.step()

        accuracy_metric.update(logits.detach(), labels)
        # Assumes criterion returns a batch mean (reduction='mean'); weighting by batch
        # size keeps a short final batch from being over-counted.
        batch_count = labels.size(0)
        loss_sum += loss.item() * batch_count
        sample_count += batch_count

    if sample_count == 0:
        raise ValueError("dataloader yielded no samples")

    epoch_accuracy = accuracy_metric.compute().item()
    epoch_loss = loss_sum / sample_count
    return epoch_accuracy, epoch_loss

def test_model(model, dataloader, criterion, accuracy_metric, device):
    """Evaluate ``model`` on ``dataloader`` and return ``(accuracy, mean_loss)``.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        accuracy_metric: Metric exposing ``reset``/``update``/``compute`` (e.g. a
            torchmetrics ``Accuracy``). It is reset first so the result covers this
            evaluation pass only.
        device: Device each batch is moved to.

    Returns:
        ``(epoch_accuracy, epoch_loss)``; ``epoch_loss`` is the per-sample mean loss.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    accuracy_metric.reset()
    loss_sum = 0.0
    sample_count = 0

    # Evaluation needs no gradients; skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="TESTING"):
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            accuracy_metric.update(logits, labels)
            # Assumes criterion returns a batch mean (reduction='mean').
            batch_count = labels.size(0)
            loss_sum += loss.item() * batch_count
            sample_count += batch_count

    if sample_count == 0:
        raise ValueError("dataloader yielded no samples")

    epoch_accuracy = accuracy_metric.compute().item()
    epoch_loss = loss_sum / sample_count
    return epoch_accuracy, epoch_loss

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1111)  # reproducible weight initialisation

num_classes = len(data.classes)  # derived from the dataset rather than hard-coded
max_epochs = 50
patience = 3
min_delta = 1e-4

model = BirdClassifier(cnn_state=False, num_class=num_classes).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Separate metric objects so train and test statistics never share accumulated state.
accuracy_score = Accuracy(task='multiclass', num_classes=num_classes).to(device)
test_accuracy_score = Accuracy(task='multiclass', num_classes=num_classes).to(device)

best_accuracy = 0
counter = 0
best_state = None

loss_scores_train = []
accuracy_scores_train = []
loss_scores_test = []
accuracy_scores_test = []

for epoch in range(max_epochs):
    print(f"\nEpoch {epoch+1}")
    # Reset per epoch so each reported accuracy covers that epoch only.
    accuracy_score.reset()
    train_acc, train_loss = train_model(model, training_dataset, criterion, optimizer, accuracy_score, device)
    test_accuracy_score.reset()
    test_acc, test_loss = test_model(model, testing_dataset, criterion, test_accuracy_score, device)

    accuracy_scores_train.append(train_acc)
    loss_scores_train.append(train_loss)
    accuracy_scores_test.append(test_acc)
    loss_scores_test.append(test_loss)

    print(f"Train Acc: {train_acc:.4f} | Train Loss: {train_loss:.4f}")
    print(f"Test Acc: {test_acc:.4f} | Test Loss: {test_loss:.4f}")

    # Early-stop on held-out accuracy: training accuracy keeps rising while the model
    # overfits. NOTE(review): the test split doubles as the validation split here; a
    # separate validation split would keep the final test numbers unbiased.
    if test_acc - best_accuracy > min_delta:
        best_accuracy = test_acc
        counter = 0
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    else:
        counter += 1
        print(f"No improvement. Early stopping counter: {counter}/{patience}")
        if counter >= patience:
            print("Early stopping triggered.")
            break

# Roll back to the best-scoring weights so later cells (e.g. saving) use them.
if best_state is not None:
    model.load_state_dict(best_state)


Epoch 1


TRAINIGN:   6%|▌         | 80/1336 [05:08<1:20:19,  3.84s/it]

In [None]:
# Mean training loss recorded for each completed epoch of the training loop.
loss_scores_train


[2.923904274965841]

In [None]:
# Training accuracy recorded for each completed epoch of the training loop.
accuracy_scores_train

[0.11756100505590439]

In [None]:
# Persist only the learned parameters (state_dict), not the module object itself.
checkpoint_path = "bird_classifier.pth"
model_weights = model.state_dict()
torch.save(model_weights, checkpoint_path)