In [1]:
from torchvision import transforms 
from PIL import Image
import torch
import os
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import random
from torch.utils.data import Dataset, DataLoader
from torch import nn

In [2]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(0)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cpu')

In [3]:
root_dir = "../datasets/waterbird_augmented/resnet152_waterbird_features"

In [5]:
class WaterbirdDatasetFeatures(Dataset):
    def __init__(self, dataset_path, transform=None):
        self.dataset = np.load(dataset_path, allow_pickle=True).item()
        
    
    def __len__(self):
        return len(self.dataset['features'])
    
    
    def __getitem__(self, idx):
        feature = self.dataset['features'][idx]        
        label = self.dataset['labels'][idx]
            
        return feature, label

In [6]:
batch_size = 64
saving_dir = "./trained"

samples = '800_sample_1'
train_filename = f'train_data_{samples}.npy'
train_set = WaterbirdDatasetFeatures(dataset_path=os.path.join(root_dir, train_filename))
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)

val_filename = 'val_data.npy'
val_set = WaterbirdDatasetFeatures(dataset_path=os.path.join(root_dir, val_filename))
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, pin_memory=True)

In [7]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linear = nn.Sequential(nn.Linear(in_features=2048, out_features=1024, bias=False),
                                    nn.BatchNorm1d(num_features=1024),
                                    nn.ReLU(),
                                    nn.Dropout(0.5, inplace=False),
                                    nn.Linear(in_features=1024, out_features=512, bias=False),
                                    nn.BatchNorm1d(num_features=512),
                                    nn.ReLU(),
                                    nn.Dropout(0.5, inplace=False),
                                    nn.Linear(in_features=512, out_features=1, bias=True),
                                    nn.Sigmoid())

    def forward(self, x):
        out = self.linear(x)
        return out

In [8]:
model = MLP().to(DEVICE)

In [9]:
learning_rate = 0.00001
num_epoch = 20

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

In [10]:
train_loss = []
val_loss = []
train_acc = []
val_acc = []
min_loss = float('inf')


for epoch in range(num_epoch):

    total_loss = 0
    corrects = 0
    
    # training loop
    for features, labels in tqdm(train_loader, desc=f'Train epoch: {epoch + 1}'):
        features = features.to(DEVICE)
        labels = labels.view(-1, 1).to(torch.float32).to(DEVICE)
        
        out = model(features)
        loss = criterion(out, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = (out >= 0.5).to(torch.float32)
        corrects += torch.sum((preds == labels)).item()
    
    accuracy = corrects / len(train_loader.dataset)
    epoch_loss = total_loss / len(train_loader)
    train_loss.append(epoch_loss)
    train_acc.append(accuracy)
    print(f'Epoch: {epoch + 1} | Loss: {epoch_loss:.4f} | Accuracy: {accuracy:.2f}')

    # validation loop
    total_loss = 0
    corrects = 0
    with torch.no_grad():
        for features, labels in tqdm(val_loader, desc=f'Validation epoch: {epoch + 1}'):
               
            features = features.to(DEVICE)
            labels = labels.view(-1, 1).to(torch.float32).to(DEVICE)
            
            out = model(features)
            
            loss = criterion(out, labels)
            total_loss += loss.item()
            
            preds = (out >= 0.5).to(torch.float32)
            corrects += torch.sum((preds == labels)).item()
        
        accuracy = corrects / len(val_loader.dataset)
        epoch_loss = total_loss / len(val_loader)
        val_loss.append(epoch_loss)
        val_acc.append(accuracy)
        print(f'Validation loss: {epoch_loss:.4f} | Accuracy: {accuracy:.2f}')

        if epoch_loss < min_loss:
            min_loss = epoch_loss
            torch.save(model.state_dict(), os.path.join(saving_dir, f"resnet150_augmented_{samples}.pth"))

Train epoch: 1:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 1 | Loss: 0.6596 | Accuracy: 0.61


Validation epoch: 1:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.6405 | Accuracy: 0.62


Train epoch: 2:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 2 | Loss: 0.5998 | Accuracy: 0.68


Validation epoch: 2:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.6096 | Accuracy: 0.68


Train epoch: 3:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 3 | Loss: 0.5553 | Accuracy: 0.73


Validation epoch: 3:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5873 | Accuracy: 0.69


Train epoch: 4:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 4 | Loss: 0.5237 | Accuracy: 0.75


Validation epoch: 4:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5828 | Accuracy: 0.70


Train epoch: 5:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 5 | Loss: 0.4930 | Accuracy: 0.76


Validation epoch: 5:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5780 | Accuracy: 0.70


Train epoch: 6:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 6 | Loss: 0.4728 | Accuracy: 0.78


Validation epoch: 6:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5689 | Accuracy: 0.71


Train epoch: 7:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 7 | Loss: 0.4563 | Accuracy: 0.79


Validation epoch: 7:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5710 | Accuracy: 0.71


Train epoch: 8:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 8 | Loss: 0.4362 | Accuracy: 0.80


Validation epoch: 8:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5577 | Accuracy: 0.72


Train epoch: 9:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 9 | Loss: 0.4226 | Accuracy: 0.81


Validation epoch: 9:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5639 | Accuracy: 0.72


Train epoch: 10:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 10 | Loss: 0.4069 | Accuracy: 0.82


Validation epoch: 10:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5727 | Accuracy: 0.72


Train epoch: 11:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 11 | Loss: 0.3941 | Accuracy: 0.83


Validation epoch: 11:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5587 | Accuracy: 0.73


Train epoch: 12:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 12 | Loss: 0.3807 | Accuracy: 0.83


Validation epoch: 12:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5526 | Accuracy: 0.73


Train epoch: 13:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 13 | Loss: 0.3623 | Accuracy: 0.85


Validation epoch: 13:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5591 | Accuracy: 0.74


Train epoch: 14:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 14 | Loss: 0.3475 | Accuracy: 0.86


Validation epoch: 14:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5662 | Accuracy: 0.73


Train epoch: 15:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 15 | Loss: 0.3381 | Accuracy: 0.86


Validation epoch: 15:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5679 | Accuracy: 0.74


Train epoch: 16:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 16 | Loss: 0.3253 | Accuracy: 0.87


Validation epoch: 16:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5691 | Accuracy: 0.72


Train epoch: 17:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 17 | Loss: 0.3186 | Accuracy: 0.87


Validation epoch: 17:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5660 | Accuracy: 0.73


Train epoch: 18:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 18 | Loss: 0.3067 | Accuracy: 0.88


Validation epoch: 18:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5768 | Accuracy: 0.73


Train epoch: 19:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 19 | Loss: 0.2914 | Accuracy: 0.89


Validation epoch: 19:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5820 | Accuracy: 0.73


Train epoch: 20:   0%|          | 0/125 [00:00<?, ?it/s]

Epoch: 20 | Loss: 0.2823 | Accuracy: 0.89


Validation epoch: 20:   0%|          | 0/19 [00:00<?, ?it/s]

Validation loss: 0.5869 | Accuracy: 0.74


In [11]:
train_statistics = {
    'train_loss': train_loss,
    'val_loss': val_loss,
    'train_acc': train_acc,
    'val_acc': val_acc
}
np.save(os.path.join(saving_dir, f'train_statistics_{samples}.npy'), train_loss)