In [30]:
# Numerical Operations
import math
import numpy as np

# Reading/Writing Data
import pandas as pd
import os
import csv

# For Progress Bar
from tqdm import tqdm

# Pytorch
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# For plotting learning curve
from torch.utils.tensorboard import SummaryWriter

from sklearn.model_selection import KFold

In [31]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
kfold_n = 5

In [32]:
def same_seed(seed): 
    '''Fixes random number generator seeds for reproducibility.'''
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [33]:
def kFold(data_set):
    train_k = []
    valid_k = []
    
    KF = KFold(n_splits = kfold_n)
    for train_index, valid_index in KF.split(data_set):
        train_set, valid_set = data_set[train_index], data_set[valid_index]
        train_k.append(np.array(train_set))
        valid_k.append(np.array(valid_set))
    
    return np.array(train_k), np.array(valid_k)

In [34]:
class CustomDataset(Dataset):
    def __init__(self, x, y=None):
        if y is None:
            self.y_data = y
        else:
            self.y_data = torch.FloatTensor(y)
        self.x_data = torch.FloatTensor(x)

    def __getitem__(self, idx):
        if self.y_data is None:
            return self.x_data[idx]
        else:
            return self.x_data[idx], self.y_data[idx]

    def __len__(self):
        return len(self.x_data)

In [35]:
class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # making the model more elaborative helps reducing the loss
        # trial 1: set Linear(), input_dim -> 16 -> 8 -> 1
        # result: could generate the result, but the loss is not reduced satisfactorilly
        
        # trial 2: 128, and reduce by /2
        # result: at some folds, the train loss could not be decreased,
        #         as well as the validation loss.
        #         they are overfitted, so reduce the to 64
        
        # trial 3: 64 -> 16 -> then reduce by 2
        # result: it produces the most loss at around 1.5 or less.
        
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 4),
            nn.ReLU(),
            nn.Linear(4, 2),
            nn.ReLU(),
            nn.Linear(2, 1)
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1)
        return x

In [36]:
def select_feat(train_data, valid_data, test_data, select_all=True):
    '''Selects useful features to perform regression'''
    y_train, y_valid = train_data[:,-1], valid_data[:,-1]
    raw_x_train, raw_x_valid, raw_x_test = train_data[:,:-1], valid_data[:,:-1], test_data

    if select_all:
        feat_idx = list(range(raw_x_train.shape[1]))
    else:
        feat_idx = [0,1,2,3,4] # TODO: Select suitable feature columns.
        
    return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], raw_x_test[:,feat_idx], y_train, y_valid

In [37]:
def trainer(train_loader, valid_loader, model, k, device):

    criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=0.9) 

    writer = SummaryWriter() # Writer of tensoboard.

    if not os.path.isdir('./models'):
        os.mkdir('./models') # Create directory of saving models.

    n_epochs, best_loss, step, early_stop_count = 3000, math.inf, 0, 0

    for epoch in range(n_epochs):
        model.train() # Set your model to train mode.
        loss_record = []

        # tqdm is a package to visualize your training progress.
        train_pbar = tqdm(train_loader, position=0, leave=True)

        for x, y in train_pbar:
            optimizer.zero_grad()               # Set gradient to zero.
            x, y = x.to(device), y.to(device)   # Move your data to device. 
            pred = model(x)             
            loss = criterion(pred, y)
            loss.backward()                     # Compute gradient(backpropagation).
            optimizer.step()                    # Update parameters.
            step += 1
            loss_record.append(loss.detach().item())
            
            # Display current epoch number and loss on tqdm progress bar.
            train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')
            train_pbar.set_postfix({'loss': loss.detach().item()})

        mean_train_loss = sum(loss_record)/len(loss_record)
        writer.add_scalar('Loss/train', mean_train_loss, step)

        model.eval() # Set your model to evaluation mode.
        loss_record = []
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                loss = criterion(pred, y)

            loss_record.append(loss.item())
            
        mean_valid_loss = sum(loss_record)/len(loss_record)
        print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')
        writer.add_scalar('Loss/valid', mean_valid_loss, step)

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            torch.save(model.state_dict(), f'./models/model{k}.ckpt') # Save your best model
            print('Saving model with loss {:.3f}...'.format(best_loss))
            early_stop_count = 0
        else: 
            early_stop_count += 1

        if early_stop_count >= 400:
            print('\nModel is not improving, so we halt the training session.')
            return

In [38]:
# Set seed for reproducibility
same_seed(5201314)

train_data, test_data = pd.read_csv('../input/ml2022spring-hw1/covid.train.csv').values, pd.read_csv('../input/ml2022spring-hw1/covid.test.csv').values
train_data, valid_data = kFold(train_data)

for k in range(kfold_n):
    # Select features
    x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data[k], valid_data[k], test_data, True)

    train_dataset, valid_dataset, test_dataset = CustomDataset(x_train, y_train), \
                                                CustomDataset(x_valid, y_valid), \
                                                CustomDataset(x_test)

    # Pytorch data loader loads pytorch dataset into batches.
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=True, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, pin_memory=True)
    
    print('************************************')
    print(f'{k+1} Fold:')
    print('************************************')
    
    model = My_Model(input_dim=x_train.shape[1]).to(device) # put your model and data on the same computation device.
    trainer(train_loader, valid_loader, model, k, device)

  # This is added back by InteractiveShellApp.init_path()


************************************
1 Fold:
************************************


Epoch [1/3000]: 100%|██████████| 9/9 [00:00<00:00, 152.88it/s, loss=111]


Epoch [1/3000]: Train loss: 141.5305, Valid loss: 90.2108
Saving model with loss 90.211...


Epoch [2/3000]: 100%|██████████| 9/9 [00:00<00:00, 154.75it/s, loss=67.9]


Epoch [2/3000]: Train loss: 101.5903, Valid loss: 73.5882
Saving model with loss 73.588...


Epoch [3/3000]: 100%|██████████| 9/9 [00:00<00:00, 151.62it/s, loss=52.8]


Epoch [3/3000]: Train loss: 57.4230, Valid loss: 62.6156
Saving model with loss 62.616...


Epoch [4/3000]: 100%|██████████| 9/9 [00:00<00:00, 148.85it/s, loss=43.5]


Epoch [4/3000]: Train loss: 50.1467, Valid loss: 79.9454


Epoch [5/3000]: 100%|██████████| 9/9 [00:00<00:00, 143.39it/s, loss=48.9]


Epoch [5/3000]: Train loss: 49.0189, Valid loss: 69.0409


Epoch [6/3000]: 100%|██████████| 9/9 [00:00<00:00, 151.79it/s, loss=41.6]


Epoch [6/3000]: Train loss: 47.5916, Valid loss: 77.0545


Epoch [7/3000]: 100%|██████████| 9/9 [00:00<00:00, 150.58it/s, loss=46.9]


Epoch [7/3000]: Train loss: 46.7998, Valid loss: 67.9822


Epoch [8/3000]: 100%|██████████| 9/9 [00:00<00:00, 156.19it/s, loss=42.8]


Epoch [8/3000]: Train loss: 46.3555, Valid loss: 66.7080


Epoch [9/3000]: 100%|██████████| 9/9 [00:00<00:00, 143.16it/s, loss=53.1]


Epoch [9/3000]: Train loss: 46.7445, Valid loss: 68.5353


Epoch [10/3000]: 100%|██████████| 9/9 [00:00<00:00, 139.05it/s, loss=47.2]


Epoch [10/3000]: Train loss: 46.3235, Valid loss: 62.3344
Saving model with loss 62.334...


Epoch [11/3000]: 100%|██████████| 9/9 [00:00<00:00, 160.39it/s, loss=45.7]


Epoch [11/3000]: Train loss: 46.0433, Valid loss: 59.1754
Saving model with loss 59.175...


Epoch [12/3000]: 100%|██████████| 9/9 [00:00<00:00, 163.85it/s, loss=39.4]


Epoch [12/3000]: Train loss: 45.5717, Valid loss: 69.9897


Epoch [13/3000]: 100%|██████████| 9/9 [00:00<00:00, 150.15it/s, loss=45.7]


Epoch [13/3000]: Train loss: 45.8614, Valid loss: 60.2751


Epoch [14/3000]: 100%|██████████| 9/9 [00:00<00:00, 139.25it/s, loss=47]


Epoch [14/3000]: Train loss: 45.6868, Valid loss: 61.5784


Epoch [15/3000]: 100%|██████████| 9/9 [00:00<00:00, 149.42it/s, loss=50.3]


Epoch [15/3000]: Train loss: 45.8956, Valid loss: 68.4353


Epoch [16/3000]: 100%|██████████| 9/9 [00:00<00:00, 154.77it/s, loss=49.4]


Epoch [16/3000]: Train loss: 45.6110, Valid loss: 59.0975
Saving model with loss 59.098...


Epoch [17/3000]: 100%|██████████| 9/9 [00:00<00:00, 167.42it/s, loss=41.8]


Epoch [17/3000]: Train loss: 45.0257, Valid loss: 63.5397


Epoch [18/3000]: 100%|██████████| 9/9 [00:00<00:00, 171.34it/s, loss=47.7]

Epoch [18/3000]: Train loss: 45.3075, Valid loss: 56.5872





Saving model with loss 56.587...


Epoch [19/3000]: 100%|██████████| 9/9 [00:00<00:00, 156.37it/s, loss=38.4]


Epoch [19/3000]: Train loss: 44.5568, Valid loss: 54.2388
Saving model with loss 54.239...


Epoch [20/3000]: 100%|██████████| 9/9 [00:00<00:00, 163.74it/s, loss=46]


Epoch [20/3000]: Train loss: 45.0506, Valid loss: 63.1402


Epoch [21/3000]: 100%|██████████| 9/9 [00:00<00:00, 134.46it/s, loss=34.7]


Epoch [21/3000]: Train loss: 44.2008, Valid loss: 52.5735
Saving model with loss 52.573...


Epoch [22/3000]: 100%|██████████| 9/9 [00:00<00:00, 128.78it/s, loss=42.6]


Epoch [22/3000]: Train loss: 44.6045, Valid loss: 62.4627


Epoch [23/3000]: 100%|██████████| 9/9 [00:00<00:00, 130.73it/s, loss=49.3]


Epoch [23/3000]: Train loss: 44.7178, Valid loss: 55.3816


Epoch [24/3000]: 100%|██████████| 9/9 [00:00<00:00, 127.65it/s, loss=39.9]


Epoch [24/3000]: Train loss: 43.9608, Valid loss: 59.1045


Epoch [25/3000]: 100%|██████████| 9/9 [00:00<00:00, 127.88it/s, loss=35.3]


Epoch [25/3000]: Train loss: 43.5389, Valid loss: 60.0237


Epoch [26/3000]: 100%|██████████| 9/9 [00:00<00:00, 161.11it/s, loss=38.1]


Epoch [26/3000]: Train loss: 43.4577, Valid loss: 56.9088


Epoch [27/3000]: 100%|██████████| 9/9 [00:00<00:00, 162.56it/s, loss=37.7]


Epoch [27/3000]: Train loss: 43.2087, Valid loss: 58.1603


Epoch [28/3000]: 100%|██████████| 9/9 [00:00<00:00, 137.07it/s, loss=46.7]


Epoch [28/3000]: Train loss: 43.5592, Valid loss: 56.5074


Epoch [29/3000]: 100%|██████████| 9/9 [00:00<00:00, 121.52it/s, loss=36.2]


Epoch [29/3000]: Train loss: 42.7559, Valid loss: 58.5701


Epoch [30/3000]: 100%|██████████| 9/9 [00:00<00:00, 124.91it/s, loss=54.6]


Epoch [30/3000]: Train loss: 43.6865, Valid loss: 55.4819


Epoch [31/3000]: 100%|██████████| 9/9 [00:00<00:00, 126.38it/s, loss=41.9]


Epoch [31/3000]: Train loss: 42.5644, Valid loss: 53.1751


Epoch [32/3000]: 100%|██████████| 9/9 [00:00<00:00, 154.34it/s, loss=40.4]


Epoch [32/3000]: Train loss: 42.0807, Valid loss: 50.0304
Saving model with loss 50.030...


Epoch [33/3000]: 100%|██████████| 9/9 [00:00<00:00, 151.92it/s, loss=46.8]


Epoch [33/3000]: Train loss: 42.1311, Valid loss: 43.6711
Saving model with loss 43.671...


Epoch [34/3000]: 100%|██████████| 9/9 [00:00<00:00, 128.67it/s, loss=39.2]


Epoch [34/3000]: Train loss: 41.3228, Valid loss: 46.2717


Epoch [35/3000]: 100%|██████████| 9/9 [00:00<00:00, 138.09it/s, loss=48.2]


Epoch [35/3000]: Train loss: 41.7217, Valid loss: 39.9704
Saving model with loss 39.970...


Epoch [36/3000]: 100%|██████████| 9/9 [00:00<00:00, 136.73it/s, loss=45.3]


Epoch [36/3000]: Train loss: 41.1260, Valid loss: 45.4940


Epoch [37/3000]: 100%|██████████| 9/9 [00:00<00:00, 157.78it/s, loss=39.4]


Epoch [37/3000]: Train loss: 40.1370, Valid loss: 43.6212


Epoch [38/3000]: 100%|██████████| 9/9 [00:00<00:00, 122.79it/s, loss=38.6]


Epoch [38/3000]: Train loss: 39.5401, Valid loss: 38.1171
Saving model with loss 38.117...


Epoch [39/3000]: 100%|██████████| 9/9 [00:00<00:00, 113.14it/s, loss=37.5]


Epoch [39/3000]: Train loss: 38.6688, Valid loss: 29.6782
Saving model with loss 29.678...


Epoch [40/3000]: 100%|██████████| 9/9 [00:00<00:00, 128.82it/s, loss=39.7]


Epoch [40/3000]: Train loss: 38.4043, Valid loss: 38.4173


Epoch [41/3000]: 100%|██████████| 9/9 [00:00<00:00, 105.73it/s, loss=33]


Epoch [41/3000]: Train loss: 36.8340, Valid loss: 30.5304


Epoch [42/3000]: 100%|██████████| 9/9 [00:00<00:00, 118.90it/s, loss=29.3]


Epoch [42/3000]: Train loss: 35.5245, Valid loss: 29.3057
Saving model with loss 29.306...


Epoch [43/3000]: 100%|██████████| 9/9 [00:00<00:00, 108.34it/s, loss=35.7]


Epoch [43/3000]: Train loss: 35.0375, Valid loss: 29.1247
Saving model with loss 29.125...


Epoch [44/3000]: 100%|██████████| 9/9 [00:00<00:00, 148.58it/s, loss=31.9]


Epoch [44/3000]: Train loss: 33.7163, Valid loss: 22.9810
Saving model with loss 22.981...


Epoch [45/3000]: 100%|██████████| 9/9 [00:00<00:00, 153.94it/s, loss=33.7]


Epoch [45/3000]: Train loss: 32.5066, Valid loss: 19.0603
Saving model with loss 19.060...


Epoch [46/3000]: 100%|██████████| 9/9 [00:00<00:00, 155.73it/s, loss=34.9]


Epoch [46/3000]: Train loss: 30.9808, Valid loss: 17.8611
Saving model with loss 17.861...


Epoch [47/3000]: 100%|██████████| 9/9 [00:00<00:00, 151.80it/s, loss=24.1]


Epoch [47/3000]: Train loss: 29.4048, Valid loss: 14.0112
Saving model with loss 14.011...


Epoch [48/3000]: 100%|██████████| 9/9 [00:00<00:00, 149.63it/s, loss=25]


Epoch [48/3000]: Train loss: 28.0681, Valid loss: 20.3244


Epoch [49/3000]: 100%|██████████| 9/9 [00:00<00:00, 152.52it/s, loss=30.3]


Epoch [49/3000]: Train loss: 26.1521, Valid loss: 19.8903


Epoch [50/3000]: 100%|██████████| 9/9 [00:00<00:00, 153.37it/s, loss=30.7]


Epoch [50/3000]: Train loss: 24.7149, Valid loss: 17.2505


Epoch [51/3000]: 100%|██████████| 9/9 [00:00<00:00, 151.56it/s, loss=21.7]


Epoch [51/3000]: Train loss: 22.7397, Valid loss: 18.4203


Epoch [52/3000]: 100%|██████████| 9/9 [00:00<00:00, 152.20it/s, loss=21.6]


Epoch [52/3000]: Train loss: 22.5549, Valid loss: 18.6629


Epoch [53/3000]: 100%|██████████| 9/9 [00:00<00:00, 155.06it/s, loss=21.9]


Epoch [53/3000]: Train loss: 19.6777, Valid loss: 15.6980


Epoch [54/3000]: 100%|██████████| 9/9 [00:00<00:00, 146.66it/s, loss=16.6]


Epoch [54/3000]: Train loss: 15.9889, Valid loss: 14.7522


Epoch [55/3000]: 100%|██████████| 9/9 [00:00<00:00, 140.07it/s, loss=11.7]


Epoch [55/3000]: Train loss: 13.3049, Valid loss: 14.6178


Epoch [56/3000]: 100%|██████████| 9/9 [00:00<00:00, 161.44it/s, loss=7.52]


Epoch [56/3000]: Train loss: 11.4813, Valid loss: 12.3295
Saving model with loss 12.329...


Epoch [57/3000]: 100%|██████████| 9/9 [00:00<00:00, 164.32it/s, loss=9.22]


Epoch [57/3000]: Train loss: 13.3556, Valid loss: 11.6439
Saving model with loss 11.644...


Epoch [58/3000]: 100%|██████████| 9/9 [00:00<00:00, 156.52it/s, loss=11.3]


Epoch [58/3000]: Train loss: 12.2225, Valid loss: 9.5435
Saving model with loss 9.543...


Epoch [59/3000]: 100%|██████████| 9/9 [00:00<00:00, 165.35it/s, loss=38.1]


Epoch [59/3000]: Train loss: 17.6795, Valid loss: 10.6981


Epoch [60/3000]: 100%|██████████| 9/9 [00:00<00:00, 143.15it/s, loss=55]


Epoch [60/3000]: Train loss: 45.2403, Valid loss: 18.5078


Epoch [61/3000]: 100%|██████████| 9/9 [00:00<00:00, 145.56it/s, loss=32.8]


Epoch [61/3000]: Train loss: 39.8655, Valid loss: 28.1714


Epoch [62/3000]: 100%|██████████| 9/9 [00:00<00:00, 144.01it/s, loss=31.8]


Epoch [62/3000]: Train loss: 35.1739, Valid loss: 16.1904


Epoch [63/3000]: 100%|██████████| 9/9 [00:00<00:00, 135.17it/s, loss=26.2]


Epoch [63/3000]: Train loss: 26.2608, Valid loss: 15.3759


Epoch [64/3000]: 100%|██████████| 9/9 [00:00<00:00, 160.72it/s, loss=18]


Epoch [64/3000]: Train loss: 20.6720, Valid loss: 15.3478


Epoch [65/3000]: 100%|██████████| 9/9 [00:00<00:00, 162.67it/s, loss=15.1]


Epoch [65/3000]: Train loss: 15.6526, Valid loss: 19.0242


Epoch [66/3000]: 100%|██████████| 9/9 [00:00<00:00, 151.79it/s, loss=11]


Epoch [66/3000]: Train loss: 11.8015, Valid loss: 33.2161


Epoch [67/3000]: 100%|██████████| 9/9 [00:00<00:00, 159.98it/s, loss=8.76]


Epoch [67/3000]: Train loss: 9.2964, Valid loss: 40.8425


Epoch [68/3000]: 100%|██████████| 9/9 [00:00<00:00, 133.47it/s, loss=10.1]


Epoch [68/3000]: Train loss: 8.6781, Valid loss: 42.6014


Epoch [69/3000]: 100%|██████████| 9/9 [00:00<00:00, 159.80it/s, loss=8.67]


Epoch [69/3000]: Train loss: 8.4028, Valid loss: 29.9316


Epoch [70/3000]: 100%|██████████| 9/9 [00:00<00:00, 158.90it/s, loss=6.68]


Epoch [70/3000]: Train loss: 6.9513, Valid loss: 23.4935


Epoch [71/3000]: 100%|██████████| 9/9 [00:00<00:00, 157.06it/s, loss=5.77]


Epoch [71/3000]: Train loss: 6.5957, Valid loss: 31.2282


Epoch [72/3000]: 100%|██████████| 9/9 [00:00<00:00, 138.16it/s, loss=13]


Epoch [72/3000]: Train loss: 9.0383, Valid loss: 19.1703


Epoch [73/3000]: 100%|██████████| 9/9 [00:00<00:00, 169.98it/s, loss=14.4]


Epoch [73/3000]: Train loss: 32.5713, Valid loss: 24.1663


Epoch [74/3000]: 100%|██████████| 9/9 [00:00<00:00, 128.15it/s, loss=23.6]


Epoch [74/3000]: Train loss: 32.0949, Valid loss: 27.6768


Epoch [75/3000]: 100%|██████████| 9/9 [00:00<00:00, 157.75it/s, loss=15.4]


Epoch [75/3000]: Train loss: 18.1665, Valid loss: 24.3084


Epoch [76/3000]: 100%|██████████| 9/9 [00:00<00:00, 165.64it/s, loss=9.76]


Epoch [76/3000]: Train loss: 13.2576, Valid loss: 29.3430


Epoch [77/3000]: 100%|██████████| 9/9 [00:00<00:00, 164.32it/s, loss=8.43]


Epoch [77/3000]: Train loss: 10.4407, Valid loss: 28.1697


Epoch [78/3000]: 100%|██████████| 9/9 [00:00<00:00, 161.69it/s, loss=8.41]


Epoch [78/3000]: Train loss: 8.1294, Valid loss: 25.6334


Epoch [79/3000]: 100%|██████████| 9/9 [00:00<00:00, 148.85it/s, loss=9.01]


Epoch [79/3000]: Train loss: 21.0925, Valid loss: 18.8973


Epoch [80/3000]: 100%|██████████| 9/9 [00:00<00:00, 160.15it/s, loss=12.2]


Epoch [80/3000]: Train loss: 17.0365, Valid loss: 16.0381


Epoch [81/3000]: 100%|██████████| 9/9 [00:00<00:00, 157.23it/s, loss=10.8]


Epoch [81/3000]: Train loss: 13.4197, Valid loss: 21.6076


Epoch [82/3000]: 100%|██████████| 9/9 [00:00<00:00, 171.54it/s, loss=8.45]


Epoch [82/3000]: Train loss: 10.3202, Valid loss: 29.0021


Epoch [83/3000]: 100%|██████████| 9/9 [00:00<00:00, 161.86it/s, loss=7.07]


Epoch [83/3000]: Train loss: 7.5136, Valid loss: 30.2107


Epoch [84/3000]: 100%|██████████| 9/9 [00:00<00:00, 131.58it/s, loss=6.79]


Epoch [84/3000]: Train loss: 6.7821, Valid loss: 26.8061


Epoch [85/3000]: 100%|██████████| 9/9 [00:00<00:00, 129.92it/s, loss=6.94]


Epoch [85/3000]: Train loss: 6.6316, Valid loss: 22.6570


Epoch [86/3000]: 100%|██████████| 9/9 [00:00<00:00, 143.87it/s, loss=4.77]


Epoch [86/3000]: Train loss: 6.7734, Valid loss: 20.8203


Epoch [87/3000]: 100%|██████████| 9/9 [00:00<00:00, 141.84it/s, loss=7.73]


Epoch [87/3000]: Train loss: 6.2976, Valid loss: 23.5647


Epoch [88/3000]: 100%|██████████| 9/9 [00:00<00:00, 136.87it/s, loss=5.64]


Epoch [88/3000]: Train loss: 6.1416, Valid loss: 22.6107


Epoch [89/3000]: 100%|██████████| 9/9 [00:00<00:00, 139.83it/s, loss=6.7]


Epoch [89/3000]: Train loss: 6.1132, Valid loss: 20.7542


Epoch [90/3000]: 100%|██████████| 9/9 [00:00<00:00, 148.28it/s, loss=7.09]


Epoch [90/3000]: Train loss: 6.3856, Valid loss: 22.9633


Epoch [91/3000]: 100%|██████████| 9/9 [00:00<00:00, 163.87it/s, loss=6.2]


Epoch [91/3000]: Train loss: 7.6345, Valid loss: 19.3250


Epoch [92/3000]: 100%|██████████| 9/9 [00:00<00:00, 160.65it/s, loss=5.17]


Epoch [92/3000]: Train loss: 6.2443, Valid loss: 22.3730


Epoch [93/3000]: 100%|██████████| 9/9 [00:00<00:00, 147.78it/s, loss=9.13]


Epoch [93/3000]: Train loss: 6.8976, Valid loss: 22.5965


Epoch [94/3000]: 100%|██████████| 9/9 [00:00<00:00, 142.91it/s, loss=5.61]


Epoch [94/3000]: Train loss: 6.1678, Valid loss: 21.8279


Epoch [95/3000]: 100%|██████████| 9/9 [00:00<00:00, 136.58it/s, loss=7.19]


Epoch [95/3000]: Train loss: 6.3534, Valid loss: 22.6518


Epoch [96/3000]: 100%|██████████| 9/9 [00:00<00:00, 146.62it/s, loss=6.82]


Epoch [96/3000]: Train loss: 7.3206, Valid loss: 20.7100


Epoch [97/3000]: 100%|██████████| 9/9 [00:00<00:00, 155.85it/s, loss=5.76]


Epoch [97/3000]: Train loss: 6.2918, Valid loss: 21.9073


Epoch [98/3000]: 100%|██████████| 9/9 [00:00<00:00, 162.08it/s, loss=5.73]


Epoch [98/3000]: Train loss: 6.3788, Valid loss: 24.4196


Epoch [99/3000]: 100%|██████████| 9/9 [00:00<00:00, 163.68it/s, loss=5.77]


Epoch [99/3000]: Train loss: 6.0155, Valid loss: 23.6115


Epoch [100/3000]: 100%|██████████| 9/9 [00:00<00:00, 160.19it/s, loss=6.49]


Epoch [100/3000]: Train loss: 7.5620, Valid loss: 24.8072


Epoch [101/3000]: 100%|██████████| 9/9 [00:00<00:00, 156.07it/s, loss=6.77]


Epoch [101/3000]: Train loss: 10.8445, Valid loss: 18.3395


Epoch [102/3000]: 100%|██████████| 9/9 [00:00<00:00, 152.72it/s, loss=5.7]


Epoch [102/3000]: Train loss: 7.1555, Valid loss: 22.5582


Epoch [103/3000]: 100%|██████████| 9/9 [00:00<00:00, 161.39it/s, loss=5.94]


Epoch [103/3000]: Train loss: 6.6057, Valid loss: 23.4598


Epoch [104/3000]: 100%|██████████| 9/9 [00:00<00:00, 135.25it/s, loss=6.45]


Epoch [104/3000]: Train loss: 5.6881, Valid loss: 20.6644


Epoch [105/3000]: 100%|██████████| 9/9 [00:00<00:00, 125.99it/s, loss=5.24]


Epoch [105/3000]: Train loss: 7.8204, Valid loss: 20.6689


Epoch [106/3000]: 100%|██████████| 9/9 [00:00<00:00, 161.58it/s, loss=4.85]


Epoch [106/3000]: Train loss: 6.3017, Valid loss: 20.1807


Epoch [107/3000]: 100%|██████████| 9/9 [00:00<00:00, 141.11it/s, loss=5.85]


Epoch [107/3000]: Train loss: 7.1204, Valid loss: 15.6230


Epoch [108/3000]: 100%|██████████| 9/9 [00:00<00:00, 140.06it/s, loss=7.69]


Epoch [108/3000]: Train loss: 6.3238, Valid loss: 21.2504


Epoch [109/3000]: 100%|██████████| 9/9 [00:00<00:00, 136.84it/s, loss=5.5]


Epoch [109/3000]: Train loss: 6.4903, Valid loss: 20.7025


Epoch [110/3000]: 100%|██████████| 9/9 [00:00<00:00, 34.83it/s, loss=6.2] 


Epoch [110/3000]: Train loss: 6.2955, Valid loss: 21.6872


Epoch [111/3000]: 100%|██████████| 9/9 [00:00<00:00, 138.72it/s, loss=5.75]


Epoch [111/3000]: Train loss: 5.7241, Valid loss: 18.4031


Epoch [112/3000]: 100%|██████████| 9/9 [00:00<00:00, 163.36it/s, loss=4.94]


Epoch [112/3000]: Train loss: 6.0742, Valid loss: 21.5108


Epoch [113/3000]: 100%|██████████| 9/9 [00:00<00:00, 152.48it/s, loss=6.83]


Epoch [113/3000]: Train loss: 6.0689, Valid loss: 19.6917


Epoch [114/3000]: 100%|██████████| 9/9 [00:00<00:00, 157.09it/s, loss=13.9]


Epoch [114/3000]: Train loss: 8.8899, Valid loss: 18.8455


Epoch [115/3000]: 100%|██████████| 9/9 [00:00<00:00, 145.37it/s, loss=5.49]


Epoch [115/3000]: Train loss: 7.3128, Valid loss: 18.3086


Epoch [116/3000]: 100%|██████████| 9/9 [00:00<00:00, 156.67it/s, loss=6.25]


Epoch [116/3000]: Train loss: 5.8964, Valid loss: 20.3150


Epoch [117/3000]: 100%|██████████| 9/9 [00:00<00:00, 162.93it/s, loss=5.4]


Epoch [117/3000]: Train loss: 5.7688, Valid loss: 21.4279


Epoch [118/3000]: 100%|██████████| 9/9 [00:00<00:00, 155.79it/s, loss=6.91]


Epoch [118/3000]: Train loss: 5.6185, Valid loss: 23.8538


Epoch [119/3000]: 100%|██████████| 9/9 [00:00<00:00, 154.79it/s, loss=4.39]


Epoch [119/3000]: Train loss: 6.0172, Valid loss: 16.6658


Epoch [120/3000]: 100%|██████████| 9/9 [00:00<00:00, 147.43it/s, loss=7.52]


Epoch [120/3000]: Train loss: 6.9684, Valid loss: 19.4248


Epoch [121/3000]: 100%|██████████| 9/9 [00:00<00:00, 130.67it/s, loss=4.7]


Epoch [121/3000]: Train loss: 6.1177, Valid loss: 19.8798


Epoch [122/3000]: 100%|██████████| 9/9 [00:00<00:00, 150.13it/s, loss=5.84]


Epoch [122/3000]: Train loss: 6.0908, Valid loss: 19.8285


Epoch [123/3000]: 100%|██████████| 9/9 [00:00<00:00, 143.54it/s, loss=5.26]


Epoch [123/3000]: Train loss: 5.3111, Valid loss: 18.4691


Epoch [124/3000]: 100%|██████████| 9/9 [00:00<00:00, 164.54it/s, loss=6.02]


Epoch [124/3000]: Train loss: 5.3851, Valid loss: 20.9739


Epoch [125/3000]: 100%|██████████| 9/9 [00:00<00:00, 144.56it/s, loss=5.92]


Epoch [125/3000]: Train loss: 6.0641, Valid loss: 18.2423


Epoch [126/3000]: 100%|██████████| 9/9 [00:00<00:00, 140.35it/s, loss=9.57]


Epoch [126/3000]: Train loss: 6.3935, Valid loss: 21.7481


Epoch [127/3000]: 100%|██████████| 9/9 [00:00<00:00, 146.26it/s, loss=5.91]


Epoch [127/3000]: Train loss: 7.4959, Valid loss: 19.9996


Epoch [128/3000]: 100%|██████████| 9/9 [00:00<00:00, 154.16it/s, loss=6.11]


Epoch [128/3000]: Train loss: 7.6948, Valid loss: 20.4532


Epoch [129/3000]: 100%|██████████| 9/9 [00:00<00:00, 159.15it/s, loss=4.41]


Epoch [129/3000]: Train loss: 5.5628, Valid loss: 17.4214


Epoch [130/3000]: 100%|██████████| 9/9 [00:00<00:00, 164.38it/s, loss=5.43]


Epoch [130/3000]: Train loss: 5.4047, Valid loss: 25.5662


Epoch [131/3000]: 100%|██████████| 9/9 [00:00<00:00, 120.76it/s, loss=6.06]


Epoch [131/3000]: Train loss: 6.0683, Valid loss: 21.7019


Epoch [132/3000]: 100%|██████████| 9/9 [00:00<00:00, 150.12it/s, loss=4.75]


Epoch [132/3000]: Train loss: 5.1951, Valid loss: 17.7464


Epoch [133/3000]: 100%|██████████| 9/9 [00:00<00:00, 154.96it/s, loss=5.38]


Epoch [133/3000]: Train loss: 6.0967, Valid loss: 22.2517


Epoch [134/3000]: 100%|██████████| 9/9 [00:00<00:00, 152.57it/s, loss=6.23]


Epoch [134/3000]: Train loss: 7.4318, Valid loss: 16.9467


Epoch [135/3000]: 100%|██████████| 9/9 [00:00<00:00, 173.07it/s, loss=9.18]


Epoch [135/3000]: Train loss: 7.9818, Valid loss: 17.0381


Epoch [136/3000]: 100%|██████████| 9/9 [00:00<00:00, 171.19it/s, loss=7.37]


Epoch [136/3000]: Train loss: 6.2856, Valid loss: 23.4187


Epoch [137/3000]: 100%|██████████| 9/9 [00:00<00:00, 169.40it/s, loss=3.49]


Epoch [137/3000]: Train loss: 5.3079, Valid loss: 19.6893


Epoch [138/3000]: 100%|██████████| 9/9 [00:00<00:00, 169.60it/s, loss=4.15]


Epoch [138/3000]: Train loss: 5.4232, Valid loss: 17.6507


Epoch [139/3000]: 100%|██████████| 9/9 [00:00<00:00, 153.91it/s, loss=4.31]


Epoch [139/3000]: Train loss: 5.4949, Valid loss: 17.9160


Epoch [140/3000]: 100%|██████████| 9/9 [00:00<00:00, 156.84it/s, loss=4.43]


Epoch [140/3000]: Train loss: 5.4802, Valid loss: 18.9629


Epoch [141/3000]: 100%|██████████| 9/9 [00:00<00:00, 153.32it/s, loss=5.8]


Epoch [141/3000]: Train loss: 6.1307, Valid loss: 17.6542


Epoch [142/3000]: 100%|██████████| 9/9 [00:00<00:00, 157.55it/s, loss=4.89]


Epoch [142/3000]: Train loss: 7.1443, Valid loss: 19.8424


Epoch [143/3000]: 100%|██████████| 9/9 [00:00<00:00, 164.93it/s, loss=4.37]


Epoch [143/3000]: Train loss: 5.5621, Valid loss: 17.2400


Epoch [144/3000]: 100%|██████████| 9/9 [00:00<00:00, 170.60it/s, loss=4.96]


Epoch [144/3000]: Train loss: 5.3272, Valid loss: 22.1838


Epoch [145/3000]: 100%|██████████| 9/9 [00:00<00:00, 151.17it/s, loss=5.11]


Epoch [145/3000]: Train loss: 5.2098, Valid loss: 19.2515


Epoch [146/3000]: 100%|██████████| 9/9 [00:00<00:00, 96.38it/s, loss=4.5]


Epoch [146/3000]: Train loss: 5.3807, Valid loss: 19.2268


Epoch [147/3000]: 100%|██████████| 9/9 [00:00<00:00, 153.89it/s, loss=5.63]


Epoch [147/3000]: Train loss: 5.1443, Valid loss: 18.4593


Epoch [148/3000]: 100%|██████████| 9/9 [00:00<00:00, 167.71it/s, loss=5.23]


Epoch [148/3000]: Train loss: 5.0014, Valid loss: 18.5423


Epoch [149/3000]: 100%|██████████| 9/9 [00:00<00:00, 166.06it/s, loss=4.87]


Epoch [149/3000]: Train loss: 5.2260, Valid loss: 19.3448


Epoch [150/3000]: 100%|██████████| 9/9 [00:00<00:00, 160.02it/s, loss=6.26]


Epoch [150/3000]: Train loss: 6.7247, Valid loss: 16.8604


Epoch [151/3000]: 100%|██████████| 9/9 [00:00<00:00, 161.05it/s, loss=3.84]


Epoch [151/3000]: Train loss: 6.7974, Valid loss: 17.8536


Epoch [152/3000]: 100%|██████████| 9/9 [00:00<00:00, 161.12it/s, loss=5.15]


Epoch [152/3000]: Train loss: 7.2484, Valid loss: 18.4457


Epoch [153/3000]: 100%|██████████| 9/9 [00:00<00:00, 154.15it/s, loss=6.89]


Epoch [153/3000]: Train loss: 5.8995, Valid loss: 16.3107


Epoch [154/3000]: 100%|██████████| 9/9 [00:00<00:00, 158.41it/s, loss=5.81]


Epoch [154/3000]: Train loss: 5.4218, Valid loss: 16.7200


Epoch [155/3000]: 100%|██████████| 9/9 [00:00<00:00, 162.58it/s, loss=6.93]


Epoch [155/3000]: Train loss: 5.7100, Valid loss: 15.6825


Epoch [156/3000]: 100%|██████████| 9/9 [00:00<00:00, 150.37it/s, loss=8.86]


Epoch [156/3000]: Train loss: 6.6819, Valid loss: 19.1962


Epoch [157/3000]: 100%|██████████| 9/9 [00:00<00:00, 144.86it/s, loss=6.52]


Epoch [157/3000]: Train loss: 6.4492, Valid loss: 20.1929


Epoch [158/3000]: 100%|██████████| 9/9 [00:00<00:00, 151.28it/s, loss=3.98]


Epoch [158/3000]: Train loss: 5.1194, Valid loss: 23.8043


Epoch [159/3000]: 100%|██████████| 9/9 [00:00<00:00, 149.54it/s, loss=4.08]


Epoch [159/3000]: Train loss: 4.8665, Valid loss: 19.1913


Epoch [160/3000]: 100%|██████████| 9/9 [00:00<00:00, 141.32it/s, loss=5.9]


Epoch [160/3000]: Train loss: 5.1394, Valid loss: 20.7459


Epoch [161/3000]: 100%|██████████| 9/9 [00:00<00:00, 131.76it/s, loss=5.57]


Epoch [161/3000]: Train loss: 7.0078, Valid loss: 19.8526


Epoch [162/3000]: 100%|██████████| 9/9 [00:00<00:00, 145.39it/s, loss=4.24]


Epoch [162/3000]: Train loss: 5.0960, Valid loss: 19.3783


Epoch [163/3000]: 100%|██████████| 9/9 [00:00<00:00, 142.17it/s, loss=4.59]


Epoch [163/3000]: Train loss: 6.9185, Valid loss: 20.1684


Epoch [164/3000]: 100%|██████████| 9/9 [00:00<00:00, 164.46it/s, loss=5.16]


Epoch [164/3000]: Train loss: 6.8237, Valid loss: 18.9068


Epoch [165/3000]: 100%|██████████| 9/9 [00:00<00:00, 162.06it/s, loss=4.27]


Epoch [165/3000]: Train loss: 5.5590, Valid loss: 23.0422


Epoch [166/3000]: 100%|██████████| 9/9 [00:00<00:00, 158.90it/s, loss=5.97]


Epoch [166/3000]: Train loss: 5.1106, Valid loss: 17.0591


Epoch [167/3000]: 100%|██████████| 9/9 [00:00<00:00, 130.59it/s, loss=6.03]


Epoch [167/3000]: Train loss: 6.5214, Valid loss: 19.4991


Epoch [168/3000]: 100%|██████████| 9/9 [00:00<00:00, 141.22it/s, loss=3.91]


Epoch [168/3000]: Train loss: 5.7110, Valid loss: 17.1510


Epoch [169/3000]: 100%|██████████| 9/9 [00:00<00:00, 135.91it/s, loss=4.74]


Epoch [169/3000]: Train loss: 5.7864, Valid loss: 18.2059


Epoch [170/3000]: 100%|██████████| 9/9 [00:00<00:00, 133.00it/s, loss=4.28]


Epoch [170/3000]: Train loss: 5.5342, Valid loss: 18.3716


Epoch [171/3000]: 100%|██████████| 9/9 [00:00<00:00, 138.33it/s, loss=6.24]


Epoch [171/3000]: Train loss: 5.5234, Valid loss: 22.0138


Epoch [172/3000]: 100%|██████████| 9/9 [00:00<00:00, 136.43it/s, loss=5.76]


Epoch [172/3000]: Train loss: 5.7872, Valid loss: 16.3206


Epoch [173/3000]: 100%|██████████| 9/9 [00:00<00:00, 150.78it/s, loss=5.97]


Epoch [173/3000]: Train loss: 5.2160, Valid loss: 18.9384


Epoch [174/3000]: 100%|██████████| 9/9 [00:00<00:00, 140.43it/s, loss=7.17]


Epoch [174/3000]: Train loss: 5.7736, Valid loss: 18.3392


Epoch [175/3000]: 100%|██████████| 9/9 [00:00<00:00, 142.69it/s, loss=4.95]


Epoch [175/3000]: Train loss: 5.1911, Valid loss: 20.5985


Epoch [176/3000]: 100%|██████████| 9/9 [00:00<00:00, 142.39it/s, loss=5.67]


Epoch [176/3000]: Train loss: 4.8909, Valid loss: 19.3678


Epoch [177/3000]: 100%|██████████| 9/9 [00:00<00:00, 124.10it/s, loss=3.7]


Epoch [177/3000]: Train loss: 4.6697, Valid loss: 18.2343


Epoch [178/3000]: 100%|██████████| 9/9 [00:00<00:00, 132.66it/s, loss=5.47]


Epoch [178/3000]: Train loss: 4.7629, Valid loss: 18.2129


Epoch [179/3000]: 100%|██████████| 9/9 [00:00<00:00, 138.93it/s, loss=5.82]


Epoch [179/3000]: Train loss: 4.6061, Valid loss: 17.0193


Epoch [180/3000]:  11%|█         | 1/9 [00:00<00:00, 87.87it/s, loss=4.14]


KeyboardInterrupt: 

In [None]:
def predict(test_loader, model, device):
    model.eval() # Set your model to evaluation mode.
    preds = []
    for x in tqdm(test_loader):
        x = x.to(device)                        
        with torch.no_grad():                   
            pred = model(x)                     
            preds.append(pred.detach().cpu())   
    preds = torch.cat(preds, dim=0).numpy()  
    return preds

In [None]:
def save_pred(preds, file):
    with open(file, 'w') as fp:
        writer = csv.writer(fp)
        writer.writerow(['id', 'tested_positive'])
        for i, p in enumerate(preds):
            writer.writerow([i, p])


test_dataset = CustomDataset(test_data, test=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

preds = [0] * (train_dataset.__len__())
for i in range(kfold_n):
    model = My_Model(input_dim=train_dataset.__len__()).to(device)
    model.load_state_dict(torch.load(f'./models/model{i}.ckpt'))
    for j, element in enumerate(predict(test_loader, model, device).tolist()):
        preds[j] += element

for i in range(train_dataset.__len__()):
    # divide by the number of kfold iteration to get average
    preds[i] = preds[i]/kfold_n
    
save_pred(preds, 'pred.csv')