In [None]:
import os
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy
from torchvision import models, transforms
from torchsummary import summary

In [None]:
def get_abs_path(n_parent: int = 0):
    """Return the absolute path `n_parent` directory levels above the working directory.

    Args:
        n_parent: How many times to step up with ``..``; 0 gives the current
            working directory itself.

    Returns:
        A resolved, absolute ``pathlib.Path``.
    """
    # One '..' component per level; no components at all means "here".
    upward_steps = ['..'] * n_parent
    return Path(*upward_steps).resolve()

In [None]:
# Fixed seed so the split (and torch's RNG, which drives random crops, loader
# shuffling and the new layer's initialisation) is reproducible after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

path = get_abs_path(1)
model_path = path / 'models' / 'deep_geo_guessr.pt'
data_path = path / 'data'

# glob() yields files in filesystem order, which differs between machines;
# sorting makes the seeded split below pick the same files every time.
images_paths = sorted(str(image_path) for image_path in data_path.glob('**/*.png'))
if not images_paths:
    raise FileNotFoundError(f'No PNG images found under {data_path}')

# Each sub-directory of the data folder is one class. The names are sorted so
# that the class -> index mapping is stable; otherwise the indices baked into
# a saved checkpoint could mean different classes on the next run.
class_names = sorted(d.name for d in data_path.iterdir() if d.is_dir())
class_labels = {name: index for index, name in enumerate(class_names)}
print('Labels:', class_labels)

# 80% train, then the remaining 20% is halved into validation and test.
train_paths, test_paths = train_test_split(images_paths, test_size=0.2, shuffle=True, random_state=SEED)
val_paths, test_paths = train_test_split(test_paths, test_size=0.5, shuffle=True, random_state=SEED)
print('train len: %d val len: %d test len: %d' % (len(train_paths), len(val_paths), len(test_paths)))

In [None]:
# Per-channel mean/std applied to every split (these are the values torchvision
# documents for its ImageNet-pretrained models).
_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


def _eval_pipeline():
    """Deterministic preprocessing: resize to the network input size, tensorise, normalise."""
    return transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        _normalize,
    ])


# Training adds augmentation: upscale, take a random 300x300 crop, then shrink
# to the same 224x224 input size used for evaluation.
_train_pipeline = transforms.Compose([
    transforms.Resize((400, 400)),
    transforms.RandomCrop((300, 300)),
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    _normalize,
])

data_transforms = {
    'train': _train_pipeline,
    'validate': _eval_pipeline(),
    'test': _eval_pipeline(),
}


def get_label_from_path(class_labels, path):
    """Look up the class label of an image from the directory it is stored in.

    The path's directory components are compared *exactly* against the class
    names, starting from the directory closest to the file. A plain substring
    test on the whole path would confuse classes whose names contain one
    another (e.g. ``Niger`` / ``Nigeria``) or that happen to appear inside an
    unrelated part of the path.

    Args:
        class_labels: Mapping of class (directory) name -> label.
        path: Image location, as a string or ``Path``.

    Returns:
        The label of the innermost directory whose name is a known class, or
        ``None`` when no directory component matches.
    """
    # Innermost directory first: the class folder is expected to sit nearest the file.
    for component in reversed(Path(path).parent.parts):
        if component in class_labels:
            return class_labels[component]
    return None


class CountryDataset(Dataset):
    """Map-style dataset of country images, labelled through their file paths.

    Each item is a ``(image, label, path)`` triple, where ``image`` is the
    result of applying ``transforms`` to the RGB image stored at ``path``.
    """

    def __init__(self, images_paths, class_labels, transforms):
        """Store the paths and resolve a label for each of them.

        Args:
            images_paths: Iterable of image file paths.
            class_labels: Mapping of class name -> label, passed on to
                ``get_label_from_path``.
            transforms: Callable applied to every loaded PIL image.

        Raises:
            ValueError: If no label can be determined for one of the paths.
        """
        self.images_paths = list(images_paths)
        self.transforms = transforms
        self.labels = []
        for image_path in self.images_paths:
            label = get_label_from_path(class_labels, image_path)
            if label is None:
                # Fail here with the offending path instead of much later,
                # when batching a None label raises an opaque collate error.
                raise ValueError(f'Cannot determine a class label for image: {image_path}')
            self.labels.append(label)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Load, convert and transform the image at position ``idx``.

        Returns:
            Tuple ``(transformed_image, label, path)``.
        """
        path = self.images_paths[idx]
        # PNG files may be RGBA, palette or greyscale; force three channels so
        # the 3-channel normalisation and the network input always match.
        # The context manager closes the underlying file once pixels are read.
        with Image.open(path) as image:
            x = image.convert('RGB')
        x = self.transforms(x)
        label = self.labels[idx]
        return x, label, path


# Loader configuration shared by all three splits.
batch_size = 10
num_workers = 0  # 0 = load batches in the main process (no worker subprocesses)

train_dataset = CountryDataset(train_paths, class_labels, data_transforms['train'])
valid_dataset = CountryDataset(val_paths, class_labels, data_transforms['validate'])
test_dataset = CountryDataset(test_paths, class_labels, data_transforms['test'])

# Only the training loader is shuffled. Evaluation metrics do not depend on
# sample order, so validation/test keep a fixed order for deterministic runs.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Sanity check: display the first three training images from disk (untransformed),
# each titled with its class name.
for sample_index in range(3):
    _, sample_label, sample_path = train_dataset[sample_index]
    sample_title = class_names[sample_label]
    raw_image = mpimg.imread(sample_path)
    plt.gca().clear()
    plt.title(sample_title)
    plt.imshow(raw_image)
    plt.show()

In [None]:
class CountryClassificator(nn.Module):
    """Pretrained ResNet-18 whose final layer is replaced to predict `num_classes` classes."""

    def __init__(self, num_classes):
        """Build the backbone and swap its head.

        Args:
            num_classes: Number of output units of the new final linear layer.
        """
        super().__init__()

        backbone = models.resnet18(pretrained=True)
        # Keep the backbone's feature width, but emit one score per class.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for the batch `x`."""
        return self.model(x)

In [None]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

# One output per known class; the model is moved to the chosen device before
# its layer summary is printed for a 3x224x224 input.
model = CountryClassificator(num_classes=len(class_labels))
model = model.to(device)
summary(model, (3, 224, 224))

In [None]:
def train(model, optimizer, loss_function, train_dataloader, valid_dataloader, epochs, batch_size):
    """Train `model` and print loss/accuracy for both splits after every epoch.

    Args:
        model: Network to optimise; batches are moved to the device of its
            parameters.
        optimizer: Optimiser already bound to the model's parameters.
        loss_function: Callable ``(preds, labels) -> scalar tensor``. The
            reported losses assume it returns the *mean* over the batch.
        train_dataloader: Yields ``(inputs, labels, _)`` training batches.
        valid_dataloader: Yields ``(inputs, labels, _)`` validation batches.
        epochs: Number of passes over the training data.
        batch_size: Unused; kept so existing calls keep working. Metrics are
            normalised by the number of samples actually seen, which stays
            correct when the last batch is smaller than the others.

    Returns:
        None. Metrics are printed, one line per epoch.
    """
    # Follow the model instead of relying on a global device variable.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(epochs):
        total_train_loss = 0.0
        total_train_true = 0
        train_samples_count = 0
        total_valid_loss = 0.0
        total_valid_true = 0
        valid_samples_count = 0

        model.train()
        for x, labels, _ in train_dataloader:
            x, labels = x.to(run_device), labels.to(run_device)
            preds = model(x)
            loss = loss_function(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            samples_in_batch = labels.size(0)
            # Undo the per-batch mean so the epoch figure is a true per-sample average.
            total_train_loss += loss.item() * samples_in_batch
            total_train_true += preds.argmax(dim=1).eq(labels).sum().item()
            train_samples_count += samples_in_batch

        model.eval()
        # No gradients are needed for validation; skipping them saves memory and time.
        with torch.no_grad():
            for x, labels, _ in valid_dataloader:
                x, labels = x.to(run_device), labels.to(run_device)
                preds = model(x)
                loss = loss_function(preds, labels)

                samples_in_batch = labels.size(0)
                total_valid_loss += loss.item() * samples_in_batch
                total_valid_true += preds.argmax(dim=1).eq(labels).sum().item()
                valid_samples_count += samples_in_batch

        # Guard against empty loaders so reporting never divides by zero.
        train_samples_count = max(train_samples_count, 1)
        valid_samples_count = max(valid_samples_count, 1)
        print( (f'Epoch {epoch+1}/{epochs} '
                f'Train loss: {total_train_loss/train_samples_count:.3f} '
                f'Train acc: {total_train_true/train_samples_count:.3f} '
                f'Valid loss: {total_valid_loss/valid_samples_count:.3f} '
                f'Valid acc: {total_valid_true/valid_samples_count:.3f} '))


def loss_function(pred, target, class_weights=None):
    """Mean cross-entropy between class scores and target labels.

    Args:
        pred: Score tensor with one column per class.
        target: Tensor of target class indices.
        class_weights: Optional per-class rescaling weights (one value per
            class). ``None`` weighs all classes equally, which gives the same
            loss as all-ones weights for any number of classes.

    Returns:
        Scalar tensor holding the (optionally class-weighted) mean loss.
    """
    weight = None
    if class_weights is not None:
        # The weights must share dtype and device with the scores.
        weight = torch.as_tensor(class_weights, dtype=pred.dtype, device=pred.device)
    return F.cross_entropy(pred, target, weight=weight, reduction='mean')


# Training run: 15 epochs of SGD with momentum over the training loader,
# validating on the validation loader after each epoch.
epochs = 15
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
train(
    model=model,
    optimizer=optimizer,
    loss_function=loss_function,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    epochs=epochs,
    batch_size=batch_size,
)

In [None]:
def test(model, dataset):
    total_true = 0
    model.eval()
    for x, labels, _ in dataset:
        x = x.to(device)
        preds = model(x.unsqueeze(0))
        _, pred_label = torch.max(preds, 1)
        total_true += pred_label.eq(labels).sum().item()
    print(f'Accuracy: {total_true/len(dataset):.3f}')


# Report accuracy on the held-out test split; the dataset itself (not the
# batched loader) is passed because `test` consumes one sample at a time.
test(model, test_dataset)

In [None]:
# Make sure the target directory exists first, so a missing `models/` folder
# cannot make the save fail and discard the trained weights.
model_path.parent.mkdir(parents=True, exist_ok=True)
# Only the weights (state dict) are stored, not the full module object.
torch.save(model.state_dict(), model_path)