In [None]:
import os
import sys
from pathlib import Path
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy
from torchvision import models, transforms
from torchsummary import summary
import wandb

In [None]:
def get_abs_path(n_parent: int = 0) -> Path:
    """Return the resolved path ``n_parent`` directory levels above the cwd.

    Args:
        n_parent: how many ``..`` steps to take from the current working
            directory; 0 (the default) yields the working directory itself.

    Returns:
        The resolved ``pathlib.Path``.
    """
    upward_steps = ['..'] * n_parent
    # Path() with no components is '.', i.e. the current working directory.
    return Path(*upward_steps).resolve()

In [None]:
class CountryDataset(Dataset):
    """Image dataset yielding ``(transformed_image, label, path)`` triples.

    Labels are resolved once, at construction time, by calling
    ``get_label_from_path`` for every entry of ``images_paths``.
    """

    def __init__(self, images_paths, class_labels, transforms):
        """
        Args:
            images_paths: sequence of image file paths.
            class_labels: mapping handed to ``get_label_from_path`` to
                resolve the label of each path.
            transforms: callable applied to every loaded PIL image.
        """
        self.images_paths = images_paths
        self.transforms = transforms
        self.labels = [
            get_label_from_path(class_labels, image_path)
            for image_path in self.images_paths
        ]

    def __len__(self):
        """Number of image paths in the dataset."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Load, convert and transform the image at position ``idx``.

        Returns:
            Tuple ``(x, label, path)``: the transformed image, its label and
            the path it was read from.
        """
        path = self.images_paths[idx]
        # Image.open is lazy and keeps the file open; the context manager
        # closes the handle as soon as the pixels have been read.
        with Image.open(path) as image:
            # PNG files can be RGBA, palette or greyscale. Force three
            # channels so a 3-channel Normalize (as used by this notebook's
            # transforms) always applies.
            x = image.convert('RGB')
        x = self.transforms(x)
        label = self.labels[idx]
        return x, label, path


class CountryClassificator(nn.Module):
    """torchvision ResNet-18 (built with ``pretrained=True``) with a new head.

    The network's ``fc`` layer is replaced by ``Dropout(0.1)`` followed by a
    ``Linear`` layer with ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()

        backbone = models.resnet18(pretrained=True)
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(head_inputs, num_classes),
        )
        self.model = backbone

    # @torch.cuda.amp.autocast() # ampere architecture
    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

In [None]:
# Hyper-parameters of the run; the same dict is handed to wandb as the run config.
HP = dict(
    epochs=15,
    batch_size=16,
    learning_rate=1.5e-3,
    momentum=0.9,
    num_workers=4,
)

# Start the Weights & Biases run that the later cells log to.
wandb.init(
    config=HP,
    entity='konradszafer',
    project='deep-geo-guessr',
    name='run 4',
    notes='''
        focal loss, sgd,
        uncleaned data, random horizontal flip, random perspective,
    ''',
)

In [None]:
# Pick the DataLoader worker count from the execution environment:
# 0 when running inside an IPython kernel (notebook), 8 otherwise.
if 'ipykernel' in sys.modules:
    print('Training using notebook')
    HP['num_workers'] = 0
else:
    print('Training using terminal')
    HP['num_workers'] = 8

# wandb.init(config=HP) ran before this override and recorded the initial
# num_workers value, so push the value actually used to the run config.
# allow_val_change is required because the key already exists there.
wandb.config.update({'num_workers': HP['num_workers']}, allow_val_change=True)

In [None]:
# Locate the project folders one level above the working directory.
path = get_abs_path(1)
model_path = path / 'models' / 'deep_geo_guessr.pt'
data_path = path / 'data'

# Sorted so the file order (and therefore the split) does not depend on the
# order in which the filesystem happens to list entries.
image_files = sorted(data_path.glob('**/*.png'))
images_paths = [str(image_file) for image_file in image_files]

# One class per sub-directory of data/. Sorted so that a class name always
# maps to the same integer label on every machine and every run.
class_names = sorted(d.name for d in data_path.iterdir() if d.is_dir())
class_labels = {name: index for index, name in enumerate(class_names)}
print('Labels:', class_labels)

train_size = 0.8
val_size = 0.1
test_size = 0.1
split_seed = 0  # fixed seed -> identical split on every Restart & Run All
train_paths, val_paths, test_paths = [], [], []

# Group images by their top-level class directory. Matching the directory
# component exactly (instead of a substring of the whole path) prevents a
# class whose name is contained in another class name, or in a parent folder
# name, from swallowing files that do not belong to it.
paths_by_class = {name: [] for name in class_names}
for image_file in image_files:
    class_dir = image_file.relative_to(data_path).parts[0]
    if class_dir in paths_by_class:
        paths_by_class[class_dir].append(str(image_file))

# Stratified split: every class contributes train_size / val_size of its
# images to train / val, the remainder goes to test. Each class is shuffled
# first so the split is not biased by file naming order.
split_rng = random.Random(split_seed)
for class_name in class_names:
    class_list = paths_by_class[class_name]
    split_rng.shuffle(class_list)
    train_count = int(train_size * len(class_list))
    val_count = int(val_size * len(class_list))
    train_paths.extend(class_list[:train_count])
    val_paths.extend(class_list[train_count:train_count + val_count])
    test_paths.extend(class_list[train_count + val_count:])

wandb.config.update({'train_samples': len(train_paths)})
wandb.config.update({'valid_samples': len(val_paths)})
print('train len: %d val len: %d test len: %d' % (len(train_paths), len(val_paths), len(test_paths)))

In [None]:
# Per-split preprocessing. Every split ends with a 224x224 resize, tensor
# conversion and per-channel normalisation (presumably the ImageNet statistics
# matching the pretrained backbone); only 'train' adds augmentation first.
data_transforms = {
    'train': transforms.Compose([
        # Augmentation: RandomChoice over a horizontal flip and a perspective warp.
        # Previously tried and currently disabled: Resize((400, 400)) +
        # RandomCrop((300, 300)), RandomVerticalFlip(), RandomRotation(45).
        transforms.RandomChoice([
            transforms.RandomHorizontalFlip(),
            transforms.RandomPerspective(),
        ]),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    # Validation and test share the same deterministic pipeline; each split
    # still gets its own Compose instance.
    **{
        split_name: transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        for split_name in ('validate', 'test')
    },
}


def get_label_from_path(class_labels, path):
    """Resolve the label of an image from the class name found in its path.

    An exact match on a directory component is preferred (searching from the
    directory closest to the file outwards), so a class whose name is a
    substring of another class name (e.g. 'Niger' / 'Nigeria') is no longer
    confused with it. If no directory matches exactly, the function falls back
    to the first class name that occurs anywhere in the path.

    Args:
        class_labels: mapping of class name -> label.
        path: image path (string or path-like).

    Returns:
        The matching label, or ``None`` when no class name occurs in ``path``.
    """
    path_text = str(path)

    # Directory components only (the file name is excluded), nearest first.
    for component in reversed(Path(path_text).parts[:-1]):
        if component in class_labels:
            return class_labels[component]

    # Fallback: substring search, e.g. for flat layouts where the class name
    # is only part of the file name.
    for name, label in class_labels.items():
        if name in path_text:
            return label

    return None


# Loader settings come from the hyper-parameter dict.
batch_size = HP['batch_size']
num_workers = HP['num_workers'] # gpus * 4, in notebook 0

# One dataset per split, each paired with the transform pipeline of that split.
train_dataset, valid_dataset, test_dataset = (
    CountryDataset(split_paths, class_labels, data_transforms[split_name])
    for split_name, split_paths in (
        ('train', train_paths),
        ('validate', val_paths),
        ('test', test_paths),
    )
)

# One shuffling DataLoader per dataset, all built with identical options.
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    for split_dataset in (train_dataset, valid_dataset, test_dataset)
)

In [None]:
# Sanity check: show three random training images titled with their class name.
for _ in range(3):
    # randrange excludes the upper bound. randint(0, len(...)) includes it and
    # could return len(train_dataset), which raises IndexError.
    idx = random.randrange(len(train_dataset))
    x, label, filename = train_dataset[idx]
    name = class_names[label]
    # Plot the untransformed file from disk, not the normalised tensor x.
    img = mpimg.imread(filename)
    plt.gca().clear()
    plt.title(name)
    plt.imshow(img)
    plt.show()

In [None]:
# Run on the GPU when torch can see one, otherwise on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print('Using {0} device'.format(device))

# Build the classifier with one output per class and move it to the device.
model = CountryClassificator(len(class_labels))
# model.load_state_dict(torch.load(model_path))
model.to(device)

# Print a layer-by-layer summary for a 3x224x224 input.
summary(model, input_size=(3, 224, 224))

In [None]:
@torch.jit.script
def get_true_preds_count(preds, labels):
    """Count how many rows of ``preds`` have their maximum at the index in ``labels``.

    The predicted label of a row is the index of its largest value along
    dimension 1; the result is the number of rows where that index equals the
    corresponding entry of ``labels``, returned as a Python number.
    """
    top_scores, winners = torch.max(preds, dim=1)
    hits = torch.eq(winners, labels)
    return torch.sum(hits).item()

def cross_entropy(pred, target, weight=None):
    """Mean cross-entropy loss with optional per-class weights.

    The previous version always passed a hard-coded weight vector of five
    ones, which only worked for exactly five classes. Uniform weighting is
    equivalent to passing no weight, so the default now works for any number
    of classes and gives the same result as before for five.

    Args:
        pred: logits, one row per sample and one column per class.
        target: class index of each sample.
        weight: optional per-class weights (tensor or sequence of floats);
            ``None`` means all classes weigh the same.

    Returns:
        Scalar tensor with the (weighted) mean loss.
    """
    if weight is not None:
        # Keep the weights on the same device / dtype as the logits instead
        # of relying on a module-level `device`.
        weight = torch.as_tensor(weight, dtype=pred.dtype, device=pred.device)
    return F.cross_entropy(pred, target, weight=weight, reduction='mean')

@torch.jit.script
def focal_loss(pred, target, gamma: float=2.0): # weight=None
    """Focal loss: cross-entropy down-weighted for well-classified samples.

    For every sample, ``p_t = exp(-ce)`` is the probability assigned to the
    true class and the sample's loss is ``(1 - p_t) ** gamma * ce``. The
    result is the mean over the batch.

    The modulating factor has to be applied per sample. Computing it from the
    batch-mean cross-entropy (as was done before) scales the whole batch by a
    single factor and does not emphasise the hard samples.

    Args:
        pred: logits, one row per sample and one column per class.
        target: class index of each sample.
        gamma: focusing parameter; 0 reduces the loss to plain cross-entropy.

    Returns:
        Scalar tensor with the mean focal loss of the batch.
    """
    ce_loss = F.cross_entropy(pred, target, reduction='none') # weight=weight,
    p_t = torch.exp(-ce_loss)
    loss = (1 - p_t)**gamma * ce_loss
    return loss.mean()

def train(model, optimizer, loss_function, train_dataloader, valid_dataloader, epochs, batch_size,
          accumulation_steps: int = 1):
    """Train ``model`` and report per-epoch loss / accuracy to stdout and wandb.

    Uses the module-level ``device`` and ``get_true_preds_count``.

    Args:
        model: network to optimise.
        optimizer: optimizer already bound to ``model``'s parameters.
        loss_function: callable ``(preds, labels) -> scalar loss``; the value
            is treated as the mean over the batch.
        train_dataloader: yields ``(x, labels, path)`` training batches.
        valid_dataloader: yields ``(x, labels, path)`` validation batches.
        epochs: number of passes over ``train_dataloader``.
        batch_size: kept for backward compatibility; sample counts are now
            taken from the batches themselves, so a smaller last batch no
            longer skews the averages.
        accumulation_steps: number of batches whose gradients are accumulated
            before each optimizer step (1 = step after every batch).
    """
    accumulation_steps = max(1, int(accumulation_steps))
    num_train_batches = len(train_dataloader)

    for epoch in range(epochs):
        total_train_loss = 0.0
        total_train_true = 0
        train_samples_count = 0
        total_valid_loss = 0.0
        total_valid_true = 0
        valid_samples_count = 0

        model.train()
        for i, (x, labels, _) in enumerate(train_dataloader):
            x = x.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            preds = model(x)
            loss = loss_function(preds, labels)

            # The loss is a batch mean: weight it by the batch length so the
            # epoch figure is a true per-sample average.
            batch_len = labels.size(0)
            total_train_loss += loss.item() * batch_len
            train_samples_count += batch_len

            # mimic bigger batch size: scale so the accumulated gradient is
            # the average over the accumulated batches.
            (loss / accumulation_steps).backward()
            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_train_batches:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            total_train_true += get_true_preds_count(preds, labels)

        # Evaluation mode: disables dropout and stops BatchNorm from updating
        # its running statistics with validation data.
        model.eval()
        with torch.no_grad():
            for x, labels, _ in valid_dataloader:
                x, labels = x.to(device), labels.to(device)
                preds = model(x)
                loss = loss_function(preds, labels)

                batch_len = labels.size(0)
                total_valid_loss += loss.item() * batch_len
                valid_samples_count += batch_len

                total_valid_true += get_true_preds_count(preds, labels)

        # max(..., 1) guards against an empty dataloader.
        train_samples_count = max(train_samples_count, 1)
        valid_samples_count = max(valid_samples_count, 1)
        print( (f'Epoch {epoch+1}/{epochs} '
                f'Train loss: {total_train_loss/train_samples_count:.4f} '
                f'Train acc: {total_train_true/train_samples_count:.3f} '
                f'Valid loss: {total_valid_loss/valid_samples_count:.4f} '
                f'Valid acc: {total_valid_true/valid_samples_count:.3f} '))

        # One log call per epoch so all four metrics share the same wandb
        # step; separate calls would each advance the step counter.
        wandb.log({
            'train loss': total_train_loss/train_samples_count,
            'valid loss': total_valid_loss/valid_samples_count,
            'train acc': total_train_true/train_samples_count,
            'valid acc': total_valid_true/valid_samples_count,
        })
        # wandb.watch(model) is intentionally not called here: registering the
        # model once, before training starts, is enough.


# Enable the cuDNN benchmark flag when training on the GPU.
if device == 'cuda':
    torch.backends.cudnn.benchmark = True


# Register the model with wandb once, before training starts.
wandb.watch(model)
# momentum must come from HP['momentum']; it was previously fed the learning
# rate (1.5e-3), which effectively disabled momentum.
optimizer = torch.optim.SGD(model.parameters(), lr=HP['learning_rate'], momentum=HP['momentum'])
train(model, optimizer, focal_loss, train_dataloader, valid_dataloader, HP['epochs'], HP['batch_size'])

In [None]:
def test(model, dataset):
    """Evaluate ``model`` sample by sample, print accuracy and plot a confusion matrix.

    Uses the module-level ``device`` and ``class_names``.

    Args:
        model: trained classifier.
        dataset: indexable dataset yielding ``(x, label, path)`` with an
            integer ``label`` per sample.
    """
    total_true = 0
    true = []
    predictions = []
    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for x, label, _ in dataset:
            x = x.to(device)
            preds = model(x.unsqueeze(0))  # add the batch dimension
            _, pred_label = torch.max(preds, 1)
            predicted = pred_label.item()
            total_true += int(predicted == label)
            true.append(label)
            predictions.append(predicted)

    # max(..., 1) guards against an empty dataset.
    print(f'Accuracy: {total_true/max(len(dataset), 1):.3f}')

    # Passing the full label list keeps the matrix square over all classes,
    # even when a class is missing from this dataset, so it always lines up
    # with display_labels.
    cm = confusion_matrix(true, predictions, labels=list(range(len(class_names))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    # Create the sized figure first and draw into it; calling plt.figure()
    # after plot() would only open a second, empty figure.
    fig, ax = plt.subplots(figsize=(8, 8))
    disp.plot(ax=ax)
    plt.show()


# Evaluate the model on the test dataset.
test(model, test_dataset)

In [None]:
def plot_predictions(model, dataset, class_names, count):
    """Print class confidences and show the image for the first ``count`` samples.

    Uses the module-level ``device``.

    Args:
        model: trained classifier.
        dataset: indexable dataset yielding ``(x, label, path)``.
        class_names: class name for each output index of the model.
        count: number of samples to show; clamped to ``len(dataset)``.
    """
    model.eval()
    # Never index past the end of the dataset.
    shown = min(count, len(dataset))
    for sample_idx in range(shown):
        x, label, filename = dataset[sample_idx]
        x = x.to(device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            preds = model(x.unsqueeze(0))  # add the batch dimension
            preds = nn.Softmax(dim=1)(preds)

        print('File path:', filename)
        # Separate index name so the outer sample index is not shadowed.
        for class_idx in range(preds.shape[1]):
            class_name = class_names[class_idx]
            confidence = preds[:, class_idx].item()
            print(class_name, str(int(confidence * 100))+'%', end=' ')
        # Finish the confidence line so the next 'File path:' starts on its own line.
        print()

        # Plot the untransformed file from disk, titled with the true class.
        img = mpimg.imread(filename)
        plt.gca().clear()
        plt.title(class_names[label])
        plt.imshow(img)
        plt.show()


# Show predictions for the first 10 samples of the test dataset.
plot_predictions(model, test_dataset, class_names, 10)

In [None]:
# torch.save(model.state_dict(), model_path)