In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
from tqdm import tqdm

In [63]:
X_train = torch.load("X_train.th")
y_train = torch.load("y_train.th").to(torch.int64)
y_train_one_hot = F.one_hot(y_train, num_classes=5).to(torch.float)

In [65]:
FEATURE_SIZE_ORDER_SEQUENCE = 183
FEATURE_SIZE_BRAND = 100
FEATURE_SIZE_F_1 = 12
FEATURE_SIZE_F_2 = 4
FEATURE_SIZE_F_3 = 100
FEATURE_SIZE_F_4 = 6
FEATURE_SIZE_F_5 = 50

In [118]:
class MyModel(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.encoder_order_sequence = nn.Linear(FEATURE_SIZE_ORDER_SEQUENCE, 20)
        self.encoder_brand = nn.Linear(FEATURE_SIZE_BRAND, 20)
        self.encoder_f_3 = nn.Linear(FEATURE_SIZE_F_3, 20)
        self.encoder_f_5 = nn.Linear(FEATURE_SIZE_F_5, 20)
        
        self.ff1 = nn.Linear(
            4*20 + FEATURE_SIZE_F_1 + FEATURE_SIZE_F_2 + FEATURE_SIZE_F_4, 50
        )
        self.ff2 = nn.Linear(50, 5)
        
        self.start_idxs = torch.cumsum(torch.Tensor([
            0, FEATURE_SIZE_ORDER_SEQUENCE, FEATURE_SIZE_BRAND, FEATURE_SIZE_F_1, FEATURE_SIZE_F_2,
            FEATURE_SIZE_F_3, FEATURE_SIZE_F_4
        ]), dim=0).int()
        self.end_idxs = torch.cumsum(torch.Tensor([
            FEATURE_SIZE_ORDER_SEQUENCE, FEATURE_SIZE_BRAND, FEATURE_SIZE_F_1, FEATURE_SIZE_F_2,
            FEATURE_SIZE_F_3, FEATURE_SIZE_F_4, FEATURE_SIZE_F_5
        ]), dim=0).int()
        
    def forward(self, x):
        inputs = []
        for start_idx, end_idx in zip(self.start_idxs, self.end_idxs):
            inputs.append(x[:, start_idx:end_idx])
        x = torch.hstack([
            self.encoder_order_sequence(inputs[0]), 
            self.encoder_brand(inputs[1]), 
            inputs[2], 
            inputs[3],
            self.encoder_f_3(inputs[4]), 
            inputs[5], 
            self.encoder_f_5(inputs[6]),  
        ])        
        
        x = nn.ReLU()(self.ff1(x))
        #x = nn.ReLU()(self.ff2(x))
        x = nn.Softmax(dim=1)(self.ff2(x))
        
        return x

In [119]:
model = MyModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

In [120]:
epochs = 5
losses = []

for _ in tqdm(range(epochs)):
    episode_loss = 0
    for data_point, target in zip(X_train, y_train_one_hot):
        optimizer.zero_grad()
        pred = model(data_point.reshape(1, -1))
        loss = loss_fn(pred, target.reshape(1, -1))
        loss.backward()
        optimizer.step()
        
        episode_loss += loss.item() / X_train.shape[0]
    losses.append(episode_loss)

100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:23<00:00, 40.74s/it]


In [121]:
losses

[1.2703041239479187,
 1.267582491399523,
 1.267582491399523,
 1.267582491399523,
 1.267582491399523]

In [122]:
pred_distr = torch.zeros(5)

for data_point, target in zip(X_train, y_train_one_hot):
    pred = model(data_point.reshape(1, -1))[0]
    pred_distr[torch.where(pred == 1)[0][0]] += 1

In [123]:
pred_distr

tensor([9745.,    0.,    0.,    0.,    0.])

In [124]:
pred, target

(tensor([1.0000e+00, 5.9010e-12, 1.0237e-11, 5.3830e-12, 4.0505e-12],
        grad_fn=<SelectBackward0>),
 tensor([0., 1., 0., 0., 0.]))

In [125]:
data_point

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 