In [1]:
# imports
import atc_dataloader, atc_model
import torch
from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# set the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def train_model(model, device, dataset_train, dataset_test, optimizer, epochs=10):
    
    # define loss function
    criterion = atc_model.PredictionLoss().to(device)
    
    # create DataLoader for batch processing
    batch_size = 32
    dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

    losses_overall = []

    # training
    for epoch in range(epochs):
        losses = []
        for (batch_in_gd, batch_in_vectors, batch_out_vectors) in dataloader:
            # forward
            outputs = model(batch_in_vectors)
            
            # loss
            loss = criterion(outputs, batch_out_vectors)
            if epoch > 0 : nn.utils.clip_grad_value_(model.parameters(), 0.0005)

            # Backpropagation and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            
        losses_overall.append(np.mean(losses))
        print(f'Epoch {epoch+1}/{epochs}, Loss: {np.mean(losses)}')

    # plot graph
    plt.plot(np.arange(0,epochs,1),losses_overall)
    plt.title('Loss function')
    plt.show()

In [15]:
# load data
TRAIN_IN="data/train_in.csv"
TRAIN_OUT="data/train_out.csv"

data_train = atc_dataloader.ATCDataset(TRAIN_IN, TRAIN_OUT)

TEST_IN="data/test_in.csv"
TEST_OUT="data/test_out.csv"

data_test = atc_dataloader.ATCDataset(TEST_IN, TEST_OUT)

In [19]:
# testing format of data
a,b,c = data_train.__getitem__(120552)
print(a)
print(c)

12012131
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


In [20]:
model = atc_model.BaseNN()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, weight_decay=0.005)

train_model(model, dataset_train=data_train, dataset_test=data_test, optimizer=optimizer, epochs=100)

11862489
