## Lab 3

### Part 1. Overfit it (1.5 points)

Будем работать с датасетом [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) (*hint: он доступен в torchvision*).

Ваша задача состоит в следующем:
1. Обучить сеть, которая покажет >= 0.92 test accuracy.
2. Пронаблюдать и продемонстрировать процесс переобучения сети с увеличением числа параметров (==нейронов) и/или числа слоев и продемонстрировать это наглядно (например, на графиках).
3. Попробовать частично справиться с переобучением с помощью подходящих приемов (Dropout/batchnorm/augmentation etc.)

*Примечание*: Пункты 2 и 3 взаимосвязаны, в п.3 Вам прелагается сделать полученную в п.2 сеть менее склонной к переобучению. Пункт 1 является независимым от пунктов 2 и 3.

### Часть 1. Обучить сеть, которая покажет >= 0.92 test accuracy.

In [1]:
import torch
import random
import numpy as np

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.backends.cudnn.deterministic = True

In [2]:
import torchvision.datasets

In [3]:
train = torchvision.datasets.FashionMNIST('./', download=True, train=True)
test = torchvision.datasets.FashionMNIST('./', download=True, train=False)

In [4]:
X_train = train.train_data
y_train = train.train_labels
X_test = test.test_data
y_test = test.test_labels



In [5]:
class LeNet5(torch.nn.Module):
    def __init__(self,
                 activation='relu',
                 pooling='max',
                 conv_size=5,
                 use_batch_norm=True):
        super(LeNet5, self).__init__()
        
        self.conv_size = conv_size
        self.use_batch_norm = use_batch_norm
        
        if activation == 'tanh':
            activation_function = torch.nn.Tanh()
        elif activation == 'relu':
            activation_function  = torch.nn.ReLU()
        else:
            raise NotImplementedError
            
        if pooling == 'avg':
            pooling_layer = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        elif pooling == 'max':
            pooling_layer  = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        else:
            raise NotImplementedError
        
        if conv_size == 5:
            self.conv1 = torch.nn.Conv2d(
                in_channels=1, out_channels=6, kernel_size=5, padding=2)
        elif conv_size == 3:
            self.conv1_1 = torch.nn.Conv2d(
                in_channels=1, out_channels=6, kernel_size=3, padding=1)
            self.conv1_2 = torch.nn.Conv2d(
                in_channels=6, out_channels=6, kernel_size=3, padding=1)
        else:
            raise NotImplementedError

        self.act1 = activation_function
        self.bn1 = torch.nn.BatchNorm2d(num_features=6)
        self.pool1 = pooling_layer
       
        if conv_size == 5:
            self.conv2 = self.conv2 = torch.nn.Conv2d(
                in_channels=6, out_channels=16, kernel_size=5, padding=0)
        elif conv_size == 3:
            self.conv2_1 = torch.nn.Conv2d(
                in_channels=6, out_channels=16, kernel_size=3, padding=0)
            self.conv2_2 = torch.nn.Conv2d(
                in_channels=16, out_channels=16, kernel_size=3, padding=0)
        else:
            raise NotImplementedError

        self.act2 = activation_function
        self.bn2 = torch.nn.BatchNorm2d(num_features=16)
        self.pool2 = pooling_layer
        
        self.fc1 = torch.nn.Linear(5 * 5 * 16, 120)
        self.act3 = activation_function
    
        self.fc2 = torch.nn.Linear(120, 84)
        self.act4 = activation_function
        
        self.fc3 = torch.nn.Linear(84, 10)
    
    def forward(self, x):
        if self.conv_size == 5:
            x = self.conv1(x)
        elif self.conv_size == 3:
            x = self.conv1_2(self.conv1_1(x))
        x = self.act1(x)
        if self.use_batch_norm:
            x = self.bn1(x)
        x = self.pool1(x)
        
        if self.conv_size == 5:
            x = self.conv2(x)
        elif self.conv_size == 3:
            x = self.conv2_2(self.conv2_1(x))
        x = self.act2(x)
        if self.use_batch_norm:
            x = self.bn2(x)
        x = self.pool2(x)
        
        x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
        x = self.fc1(x)
        x = self.act3(x)
        x = self.fc2(x)
        x = self.act4(x)
        x = self.fc3(x)
        
        return x

In [6]:
lenet5 = LeNet5()

In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
lenet5 = lenet5.to(device)

In [8]:
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lenet5.parameters(), lr=1.0e-3)

In [9]:
from sklearn.model_selection import train_test_split


X_test = X_test.unsqueeze(1).float()

#X_train, X_val, y_train, y_val = train_test_split(X_train, 
#                                                    y_train, 
#                                                    test_size=0.1, 
#                                                    random_state=42)

X_train = X_train.unsqueeze(1).float()
#X_val = X_val.unsqueeze(1).float()

In [10]:
batch_size = 64

test_accuracy_history = []
test_loss_history = []

X_test = X_test.to(device)
y_test = y_test.to(device)


#X_val = X_val.to(device)
#y_val = y_val.to(device)



for epoch in range(20):
    order = np.random.permutation(len(X_train))
    for start_index in range(0, len(X_train), batch_size):
        optimizer.zero_grad()
        
        batch_indexes = order[start_index:start_index+batch_size]
        
        X_batch = X_train[batch_indexes].to(device)
        y_batch = y_train[batch_indexes].to(device)
        
        preds = lenet5.forward(X_batch) 
        
        loss_value = loss(preds, y_batch)
        loss_value.backward()
        
        optimizer.step()
        
    test_preds = lenet5.forward(X_test)
    test_loss_history.append(loss(test_preds, y_test).data.cpu())
    
    accuracy = float((test_preds.argmax(dim=1) == y_test).float().mean().data.cpu())
    test_accuracy_history.append(accuracy)
    
    #val_preds = lenet5.forward(X_val)    
    #accuracy_val = float((val_preds.argmax(dim=1) == y_val).float().mean().data.cpu())
    
    print('Epoch:', epoch, '        acc_train:', np.round(accuracy, 4))

Epoch: 0         acc_train: 0.8737
Epoch: 1         acc_train: 0.8945
Epoch: 2         acc_train: 0.8934
Epoch: 3         acc_train: 0.9031
Epoch: 4         acc_train: 0.9007
Epoch: 5         acc_train: 0.905
Epoch: 6         acc_train: 0.9036
Epoch: 7         acc_train: 0.8975
Epoch: 8         acc_train: 0.9075
Epoch: 9         acc_train: 0.9038
Epoch: 10         acc_train: 0.9056
Epoch: 11         acc_train: 0.9022
Epoch: 12         acc_train: 0.9027
Epoch: 13         acc_train: 0.9024
Epoch: 14         acc_train: 0.8962
Epoch: 15         acc_train: 0.8979
Epoch: 16         acc_train: 0.9
Epoch: 17         acc_train: 0.9002
Epoch: 18         acc_train: 0.9033
Epoch: 19         acc_train: 0.8966
