In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from torchsummary import summary
import os
import time

from dataset import HackathonDataset
from convnet import ConvNet
from resnet import ResNet

from config import DATA_DIR, DEVICE, USE_RAW, AUTO_ROTATE

In [2]:
class Ensemble:
    """Train several independent instances of a model and average their predictions.

    Each instance is built as ``Model(device)`` and is expected to expose
    ``fit(train_dataloader, test_dataloader, n_epochs, print_frequency)``
    returning a score, and ``predict(dataloader)`` returning an array-like.

    NOTE(review): instances are ranked by ascending score and the lowest-scoring
    ones are kept, which is the right choice if the score is a loss (lower is
    better) -- confirm against what ``Model.fit`` actually returns.

    Parameters
    ----------
    Model : callable
        Factory called as ``Model(device)`` once per estimator.
    device :
        Forwarded untouched to ``Model``.
    n_estimators : int
        Number of instances to create.
    keep_fraction : float, optional
        Fraction of the fitted instances (lowest scores first) whose predictions
        are averaged by ``predict``. At least one instance is always kept.
        Defaults to 0.8.
    """

    def __init__(self, Model, device, n_estimators, keep_fraction=0.8):
        self.Model = Model
        self.instances = [self.Model(device) for _ in range(n_estimators)]
        self.performances = []
        self.keep_fraction = keep_fraction

    def fit(self, train_dataloader, test_dataloader, n_epochs, print_frequency):
        """Fit every instance in turn and record the score each one returns."""
        # Start from an empty list so that scores stay aligned with
        # self.instances when fit() is called more than once.
        self.performances = []
        for it, instance in enumerate(self.instances):
            print(f"\n=== Training instance {it+1}/{len(self.instances)} ===\n")
            score = instance.fit(train_dataloader, test_dataloader, n_epochs, print_frequency)
            self.performances.append(score)

    def predict(self, dataloader):
        """Return the element-wise mean prediction of the retained instances.

        Only instances that already have a recorded score take part (so an
        interrupted ``fit`` still yields a usable, smaller ensemble).

        Raises
        ------
        RuntimeError
            If no instance has been fitted yet.
        """
        scored = list(zip(self.performances, self.instances))
        if not scored:
            raise RuntimeError("Ensemble.predict() called before any instance was fitted")
        # Sort on the score only: comparing the model objects themselves (what a
        # plain tuple sort does on ties) raises TypeError.
        scored.sort(key=lambda pair: pair[0])
        # Always keep at least one instance, otherwise the mean below is NaN.
        n_keep = max(1, int(len(scored) * self.keep_fraction))
        predictions = [instance.predict(dataloader) for _, instance in scored[:n_keep]]
        return np.mean(predictions, axis=0)

In [6]:
# Hyper-parameters shared by the training and prediction cells below.
n_epochs = 4  # epochs each ensemble member is trained for
n_estimators = 10  # number of model instances in the ensemble
print_frequency = 3  # forwarded to fit(); the training log shows 3 progress reports per epoch -- confirm in the model's fit()
batch_size = 8  # Training batch size: large batches often fail to converge, so we use small ones even though they are slower
pred_batch_size = 128  # Batch size for validation/test loaders: convergence is not a concern at prediction time, so it can be larger

In [7]:
#========================NOTE============================
# Training sometimes fails to converge (cause unknown).
# If the training loss is stuck around 22 and the validation loss is stuck around 10,
# reset the model by running this cell again, then relaunch training.
#========================END OF NOTE=====================

# Leave two cores free, but never ask for a negative worker count:
# os.cpu_count() can return None, and machines with <= 2 cores would otherwise
# yield 0 or less (DataLoader rejects negative values; 0 loads in the main process).
n_workers = max(0, (os.cpu_count() or 1) - 2)

# Training data: shuffled, with transform=True.
dataset = HackathonDataset(DATA_DIR + 'mixed_train.csv', DATA_DIR, USE_RAW, transform=True, auto_rotate=AUTO_ROTATE)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers)

# Validation data: fixed order, no transform flag, larger prediction batch size.
val_dataset = HackathonDataset(DATA_DIR + 'mixed_validation.csv', DATA_DIR, USE_RAW, auto_rotate=AUTO_ROTATE)
val_dataloader = DataLoader(val_dataset, batch_size=pred_batch_size, shuffle=False, num_workers=n_workers)

# Building the ensemble creates n_estimators fresh ConvNet instances.
model = Ensemble(ConvNet, DEVICE, n_estimators)

In [None]:
# Train every ensemble member in turn; the validation loader is what the
# ensemble's fit() receives as its `test_dataloader` argument.
model.fit(
    train_dataloader=dataloader,
    test_dataloader=val_dataloader,
    n_epochs=n_epochs,
    print_frequency=print_frequency,
)


=== Training instance 1/10 ===

Epoch 1/4
Number of batches viewed : 2373
Current training loss : 8.573170372669173
Current validation loss : 6.6105760082485165
Number of batches viewed : 4747
Current training loss : 7.51407811997212
Current validation loss : 5.36526443075946
Number of batches viewed : 7121
Current training loss : 7.041330185273965
Current validation loss : 4.563555572915265
The epoch took  36.25 seconds
Epoch 2/4
Number of batches viewed : 2373
Current training loss : 6.847223588219187
Current validation loss : 5.02344570948383
Number of batches viewed : 4747
Current training loss : 6.756312635935147
Current validation loss : 4.298437840356602
Number of batches viewed : 7121
Current training loss : 6.550409719636687
Current validation loss : 6.6702773308190775
The epoch took  36.52 seconds
Epoch 3/4
Number of batches viewed : 2373
Current training loss : 6.475281699776348
Current validation loss : 4.516090115224283
Number of batches viewed : 4747
Current training los

# Evaluation on Test Data

In [10]:
# Test data is read in a fixed order (shuffle=False) so that predictions can be
# matched back to image file names afterwards.
test_dataset = HackathonDataset(DATA_DIR + 'mixed_test.csv', DATA_DIR, USE_RAW, auto_rotate=AUTO_ROTATE)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=pred_batch_size,
    shuffle=False,
    # os.cpu_count() may be None or <= 2; clamp so the worker count is never negative.
    num_workers=max(0, (os.cpu_count() or 1) - 2),
)

In [11]:
# Gather the image ids batch by batch, in the order the loader yields them.
image_file_names = [
    file_name
    for batch in test_dataloader
    for file_name in batch['image_file_name']
]

# Second pass over the same loader: averaged predictions of the ensemble.
predictions = model.predict(test_dataloader)

# One row per image: its id next to the predicted value.
kaggle_df = pd.DataFrame(
    {
        'image_id': image_file_names,
        'predicted_z': predictions,
    }
)

In [12]:
# DataFrame.to_csv does not create missing directories, so make sure the
# output folder exists before writing.
os.makedirs('predictions', exist_ok=True)
# NOTE(review): the ':' in the timestamp is not a valid filename character on
# Windows; kept as-is so existing file naming does not change.
kaggle_df.to_csv('predictions/prediction-' + datetime.now().strftime("%d-%m-%y:%H-%M") + '.csv', index=False)