In [12]:
import torch
import torch.nn.functional as F
import torchvision.transforms as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import datetime
import os
from timeit import default_timer as timer
from typing import Tuple, List, Type, Dict, Any

ValueError: module functions cannot set METH_CLASS or METH_STATIC

##  Load and examine the data

In [None]:
# Load the sample index.
# NOTE(review): pickle.load can execute arbitrary code — only open index files from a trusted source.
with open('./geo_kaggle_data/index.pkl', mode='rb') as index_file:
    data_index = pickle.load(index_file)

# Peek at a couple of entries to see what a record looks like.
data_index[1:3]

In [None]:
# Tabulate the raw index so it can be explored with pandas.
dataframe = pd.DataFrame(data=data_index)

In [None]:
# First rows of the raw index.
dataframe.head(n=5)

In [None]:
# Summary statistics of the target label.
dataframe.loc[:, 'observed_TCC'].describe()

In [None]:
# Stacked histogram of the label per mission.
# `discrete=True` centres one bar on each label value instead of using automatic bins.
fig, ax = plt.subplots(figsize=(10, 10))
sns.histplot(data=dataframe, x='observed_TCC', hue='mission', multiple='stack', discrete=True, ax=ax)
ax.set(title='Label distribution per mission', xlabel='observed_TCC', ylabel='Count');

The label distribution is skewed, so it might be useful to augment the samples of every label except *8* to counter the imbalance.  

In [None]:
# Mean label per mission.
# Select the column *before* aggregating: averaging every column also hits the
# string columns (e.g. file names), which raises TypeError in pandas >= 2.0.
dataframe.groupby(by='mission')['observed_TCC'].mean()

Check the labels of mission ```AI49```: all of its samples have label *8*, which looks suspicious.

In [None]:
# Summary of the labels of mission AI49 only.
ai49_mask = dataframe['mission'] == 'AI49'
dataframe.loc[ai49_mask, 'observed_TCC'].describe()

So every sample of ```AI49``` is annotated with *8*.  
However, its photos appear to show labels 0, 1 or 2, so the whole of ```AI49``` is mislabeled. It has only 160 samples, so we can drop it without doing any serious damage to the model.

I'll use only the columns ```jpg_filename```, ```mission```, ```observations_dt``` and ```observed_TCC```, because the others are of no use to the model.

In [None]:
# Keep only the columns the dataset class reads.
kept_columns = ['jpg_filename', 'mission', 'observations_dt', 'observed_TCC']
dataframe = dataframe.loc[:, kept_columns]

In [None]:
# Drop every row of the mislabeled mission AI49 by indexing on 'mission'.
# NOTE: not re-runnable on its own — afterwards 'mission' is the index, not a column.
dataframe = dataframe.set_index('mission').drop(index='AI49')

In [None]:
# Move 'mission' from the index back into a regular (first) column.
dataframe = dataframe.reset_index()

In [None]:
# Sanity check: no sample of the mislabeled mission AI49 may remain.
assert not dataframe['mission'].eq('AI49').any(), 'AI49 samples are still present'
dataframe

## Create custom dataset

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Image dataset backed by an annotations DataFrame.

    Every row of ``annotations`` must provide ``mission``, ``observations_dt``
    (an object with a ``.date()`` method, e.g. a pandas Timestamp),
    ``jpg_filename`` and ``observed_TCC``. The image of a row is read from
    ``<root_dir>/<mission>/snapshots/snapshots-<observation date>/<jpg_filename>``.
    """
    
    def __init__(self, annotations, root_dir, transforms = None):
        
        super().__init__()
        
        # Per-sample metadata, accessed positionally via .iloc.
        self.annotations = annotations
        # Resolved once so later changes of the working directory do not matter.
        self.root_dir = os.path.abspath(root_dir)
        # Optional callable applied to the image array returned by plt.imread.
        self.transforms = transforms
        
    def __len__(self):
        """Number of annotated samples."""
        return len(self.annotations)

    def _image_path(self, row):
        """Build the on-disk path of the image described by annotation ``row``."""
        snapshot_dir = 'snapshots-' + str(row['observations_dt'].date())
        return os.path.join(self.root_dir, row['mission'], 'snapshots', snapshot_dir, row['jpg_filename'])
        
    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at position ``index``.

        ``image`` is the array from ``plt.imread``, passed through
        ``self.transforms`` when one is set; ``label`` is a 0-d integer tensor.
        """
        # Fetch the row once instead of re-indexing the DataFrame for every field.
        row = self.annotations.iloc[index]
        image = plt.imread(self._image_path(row))
        label = torch.tensor(int(row['observed_TCC']))
        
        if self.transforms is not None:
            image = self.transforms(image)
            
        return (image, label)

In [None]:
# Preprocessing + augmentation: image array -> PIL image -> 256x256 -> random
# horizontal/vertical flips and rotation (up to 20 degrees) -> tensor.
# NOTE(review): the train/validation split made later from this dataset shares this
# pipeline, so validation images are randomly augmented as well.
transforms = tf.Compose([
    tf.ToPILImage(),
    tf.Resize([256, 256]),
    tf.RandomHorizontalFlip(),
    tf.RandomVerticalFlip(),
    tf.RandomRotation(20),
    tf.ToTensor(),
])

In [None]:
# Dataset over the cleaned annotations; the transforms run on every sample access.
SkyData = CustomDataset(annotations=dataframe, root_dir='geo_kaggle_data', transforms=transforms)

In [None]:
# Inspect one sample: show its tensor shape and label rather than dumping every pixel value.
sample_image, sample_label = SkyData[254]
sample_image.shape, int(sample_label)

In [None]:
# 3x3 grid of (augmented) samples spread evenly over the whole dataset, so every
# index stays within range regardless of the dataset size.
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))
sample_indices = np.linspace(0, len(SkyData) - 1, num=axes.size, dtype=int)
for ax, sample_index in zip(axes.flat, sample_indices):
    sample, label = SkyData[int(sample_index)]
    # CHW tensor -> HWC array, the layout imshow expects.
    ax.imshow(sample.permute(1, 2, 0).cpu().numpy())
    ax.set_title('Label: {}'.format(int(label)))
    ax.axis('off')
fig.tight_layout()

In [None]:
def train_single_epoch(model : torch.nn.Module,
                       optimizer : torch.optim.Optimizer,
                       loss_function : torch.nn.Module,
                       data_loader : torch.utils.data.DataLoader,
                       device : torch.device = None) -> float:
    """Run one training pass over ``data_loader``; return the mean per-sample loss.

    Args:
        model: network to train; it is switched to train mode.
        optimizer: optimizer stepped once per batch.
        loss_function: criterion called as ``loss_function(output, y)``; assumed to
            average over the batch (PyTorch's default ``reduction='mean'``).
        data_loader: iterable of ``(X, y)`` batches.
        device: where batches are moved; defaults to the device of the model's
            first parameter.

    Returns:
        The per-sample average training loss as a plain Python float.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    
    model.train()
    loss_sum = 0.0
    n_samples = 0
    
    for X, y in data_loader:
        X, y = X.to(device), y.to(device)
        
        model.zero_grad()
        output = model(X)
        
        loss = loss_function(output, y)
        loss.backward()
        optimizer.step()
        
        # .item() detaches the value: accumulating the tensor itself would keep every
        # batch's autograd graph alive. Weighting the batch mean by the batch size
        # makes the final division a true per-sample mean.
        batch_size = y.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
    
    if n_samples == 0:
        raise ValueError('data_loader yielded no samples')
    
    return loss_sum / n_samples

In [None]:
@torch.no_grad()
def validate_single_epoch(model: torch.nn.Module,
                          loss_function: torch.nn.Module, 
                          data_loader: torch.utils.data.DataLoader,
                          device: torch.device = None) -> Dict[str, float]:
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loss_function: criterion called as ``loss_function(output, y)``; assumed to
            average over the batch (PyTorch's default ``reduction='mean'``).
        data_loader: iterable of ``(X, y)`` batches with integer class targets.
        device: where batches are moved; defaults to the device of the model's
            first parameter.

    Returns:
        ``{'loss': mean per-sample loss, 'accuracy': percentage of correct
        argmax predictions}``, both plain Python floats.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    
    for X, y in data_loader:
        X, y = X.to(device), y.to(device)
        
        output = model(X)
        
        # Weight the batch-mean loss by the batch size so the final division
        # is a true per-sample mean.
        batch_size = y.size(0)
        loss_sum += loss_function(output, y).item() * batch_size
        
        # Predicted class = index of the highest score along the class dimension.
        n_correct += (output.argmax(dim = 1) == y).sum().item()
        n_samples += batch_size
    
    if n_samples == 0:
        raise ValueError('data_loader yielded no samples')
    
    return {'loss' : loss_sum / n_samples, 'accuracy' : 100.0 * n_correct / n_samples}

In [None]:
def plot_learning_curves(loss_list, accuracy_list, best_epoch):
    """
    Plot loss evolution on the training and dev sets, and accuracy evolution
    on the dev set, as two separate figures.

    Args:
        loss_list: dict with 'train' and 'valid' sequences of per-epoch losses.
        accuracy_list: sequence of per-epoch dev-set accuracies.
        best_epoch: epoch marked with a dashed red line, or None to omit the marker.
    """
    
    def _as_floats(values):
        # Accept plain numbers as well as one-element tensors (possibly on the GPU
        # or still attached to the autograd graph), which matplotlib cannot plot.
        return [float(v) for v in values]
    
    def _mark_best(ax):
        # Only draw the best-epoch marker when one is known.
        if best_epoch is not None:
            ax.axvline(best_epoch, color = 'r', ls = '--', label = 'Best model')
    
    # Plot learning loss curves
    fig, ax = plt.subplots()
    ax.plot(_as_floats(loss_list['train']), label = 'Training set')
    ax.plot(_as_floats(loss_list['valid']), label = 'Dev set')
    _mark_best(ax)
    ax.set(title = 'Loss evolution', xlabel = 'epochs', ylabel = 'Loss')
    ax.legend()
    plt.show()
    
    # Plot accuracy curve
    fig, ax = plt.subplots()
    ax.plot(_as_floats(accuracy_list), color = 'g', label = 'Dev set')
    _mark_best(ax)
    ax.set(title = 'Accuracy evolution on validation set', xlabel = 'epochs', ylabel = 'Accuracy')
    ax.legend()
    plt.show()

In [None]:
def train_model(model: torch.nn.Module, 
                train_dataset: torch.utils.data.Dataset,
                valid_dataset: torch.utils.data.Dataset,
                loss_function: torch.nn.Module = torch.nn.CrossEntropyLoss(),
                optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
                optimizer_params: Dict = None,
                initial_lr = 0.01,
                lr_scheduler_class: Any = torch.optim.lr_scheduler.ReduceLROnPlateau,
                lr_scheduler_params: Dict = None,
                batch_size = 64,
                max_epochs = 1000,
                early_stopping_patience = 20, 
                best_model_root = './best_model.pth'):
    """Train ``model`` with early stopping on the validation loss, then plot learning curves.

    Args:
        model: network to train (already on its target device).
        train_dataset / valid_dataset: datasets yielding ``(X, y)`` samples.
        loss_function: training / validation criterion.
        optimizer_class: optimizer type, constructed as
            ``optimizer_class(model.parameters(), lr=initial_lr, **optimizer_params)``.
        optimizer_params: extra optimizer keyword arguments (None means none).
        initial_lr: starting learning rate.
        lr_scheduler_class: scheduler type; ``step`` is called with the validation loss.
        lr_scheduler_params: extra scheduler keyword arguments (None means none).
        batch_size: batch size of both data loaders.
        max_epochs: upper bound on the number of epochs.
        early_stopping_patience: stop once more than this many epochs have passed
            since the best validation loss.
        best_model_root: path where the best model so far is saved with ``torch.save``.
    """
    # Copy so the caller's dicts are never shared or mutated.
    optimizer_params = dict(optimizer_params or {})
    lr_scheduler_params = dict(lr_scheduler_params or {})
    
    optimizer = optimizer_class(model.parameters(), lr=initial_lr, **optimizer_params)
    lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_params)
    
    # Pinned (page-locked) host memory only helps host->GPU copies.
    pin_memory = torch.cuda.is_available()
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size, pin_memory = pin_memory, num_workers = 1)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, num_workers = 1)

    best_valid_loss = None
    best_epoch = None
    loss_list = {'train' : list(), 'valid' : list()}
    accuracy_list = list()
    
    for epoch in range(max_epochs):
        
        print(f'Epoch {epoch}')
        
        start = timer()
        
        # Train for one epoch and record the training loss
        train_loss = train_single_epoch(model, optimizer, loss_function, train_loader)
        loss_list['train'].append(train_loss)
        
        # Evaluate performance on the validation set
        valid_metrics = validate_single_epoch(model, loss_function, valid_loader)
        loss_list['valid'].append(valid_metrics['loss'])
        accuracy_list.append(valid_metrics['accuracy'])
        
        print('time:', timer() - start)
        print(f'Validation metrics: \n{valid_metrics}')

        lr_scheduler.step(valid_metrics['loss'])
        
        if best_valid_loss is None or best_valid_loss > valid_metrics['loss']:
            print(f'-----Best model yet, saving-----')
            best_valid_loss = valid_metrics['loss']
            best_epoch = epoch
            # Saves the whole pickled module, not just its state_dict.
            torch.save(model, best_model_root)
            
        if epoch - best_epoch > early_stopping_patience:
            print('Early stopping triggered')
            break
    
    # Plot whether training stopped early or exhausted max_epochs.
    if best_epoch is not None:
        plot_learning_curves(loss_list, accuracy_list, best_epoch)

## Create a model

In [None]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else 'cpu')
if not use_gpu:
    print('Using CPU')
else:
    print('Using GPU', f'({torch.cuda.get_device_name()})')

In [None]:
class Net(torch.nn.Module):
    """Small three-convolution CNN classifier that outputs log-probabilities.

    Layer sizes are fixed for 3-channel 256x256 inputs:
    256 -conv1(k=32)-> 225 -conv2(k=32, s=4)-> 49 -conv3(k=7, s=7)-> 7,
    so the final linear layer sees 3 * 7 * 7 = 147 features.

    NOTE(review): ``input_resolution``, ``input_channels``, ``hidden_layer_features``
    and ``activation`` are accepted but unused — the architecture above is
    hard-coded (3 input channels, ReLU activations). Only ``num_classes`` is used.
    """
    
    def __init__(self, 
                 input_resolution: Tuple[int, int] = (512, 512),
                 input_channels: int = 1, 
                 hidden_layer_features: List[int] = [256, 256, 256],
                 activation: Type[torch.nn.Module] = torch.nn.Tanh,
                 num_classes: int = 9):
        
        super().__init__()
        
        self.conv1 = torch.nn.Conv2d(3, 3, 32)
        self.conv2 = torch.nn.Conv2d(3, 3, 32, 4)
        self.conv3 = torch.nn.Conv2d(3, 3, 7, 7)
        
        self.fc1 = torch.nn.Linear(7*7*3, num_classes)
        
    def forward(self, X):
        """Map a (batch, 3, 256, 256) tensor to (batch, num_classes) log-probabilities."""
        
        X = F.relu(self.conv1(X))
        X = F.relu(self.conv2(X))
        X = F.relu(self.conv3(X))
        
        # flatten(1) keeps the batch dimension, so a wrong input size fails in fc1
        # instead of view(-1, 147) silently folding several samples into one row.
        X = X.flatten(1)
        X = self.fc1(X)
        
        # CrossEntropyLoss applies log_softmax itself; doing it here as well is
        # redundant but harmless because log_softmax is idempotent.
        output = F.log_softmax(X, dim = 1)
        
        return output

In [None]:
# Instantiate the network on the selected device and report its size.
model = Net().to(device)
print(model)
n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print('Total number of trainable parameters', n_trainable)

In [None]:
# Hold out a fixed-size validation split; the seeded generator makes the split
# identical on every Restart & Run All.
VALID_SIZE = 15000
SPLIT_SEED = 42
train_dataset, valid_dataset = torch.utils.data.random_split(
    SkyData,
    [len(SkyData) - VALID_SIZE, VALID_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

### Train

In [None]:
# Full training run with a large batch size; see train_model for the remaining defaults.
train_model(
    model,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    loss_function=torch.nn.CrossEntropyLoss(),
    initial_lr=0.01,
    batch_size=1024,
    max_epochs=500,
)