In [2]:
import pandas as pd
import numpy as np

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from matplotlib import pyplot as plt
from tqdm import tqdm

In [3]:
n_c = 90
hidden_size = 8
file = f'./Dataset/dataset_reduced_{n_c}.json'

### Datasets


In [4]:
# dataloaders
class IPARC(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        return {
            'img_in': torch.tensor(row['input_reduced']).reshape(-1),
            'img_out': torch.tensor(row['output_reduced']).reshape(-1),
            'operation': torch.tensor(row['operation']).reshape(1, -1),
            'kernel': torch.tensor(row['kernel']),
        }

    @staticmethod
    def collate(batch):
        return {
            'img_in': torch.stack([x['img_in'] for x in batch]),
            'img_out': torch.stack([x['img_out'] for x in batch]),
            'operation': torch.stack([x['operation'] for x in batch]),
            'kernel': torch.stack([x['kernel'] for x in batch]),
        }

### Model


In [5]:
class Model(nn.Module):
    def __init__(self, n_c, hdsz):
        super(Model, self).__init__()
        self.op_linreg = nn.Linear(4 * (n_c + n_c) + hdsz, 1)
        self.kernel_linreg = nn.Linear(4 * (n_c + n_c) + hdsz, 8)
        self.hidden_encoder = nn.Linear(hdsz + 1 + 8, hdsz)

    def forward(self, img_in, img_out, op_prev, kernel_prev, hidden):
        hidden = torch.tanh(self.hidden_encoder(torch.cat([hidden, op_prev, kernel_prev], dim=-1)))

        cat = torch.cat([img_in, img_out, hidden], dim=-1)
        op_logit = self.op_linreg(cat)
        kernel_logit = self.kernel_linreg(cat)

        return op_logit, kernel_logit, hidden

### Training Loop


In [10]:
# trainig loop
def train(model, train_loader, test_loader, epochs=10, lr=0.001, device='cpu'):
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    losses = {'train': [], 'test': []}

    for epoch in range(epochs):
        model.train()
        loss_tot = 0
        for batch in train_loader:
            bs = batch['img_in'].shape[0]

            img_in = batch['img_in'].to(device)
            img_out = batch['img_out'].to(device)
            op = batch['operation'].to(device)
            kernel = batch['kernel'].to(device)

            prev_op, prev_kernel = torch.zeros_like(op[:, :, 0], device=device), torch.zeros_like(
                kernel[:, 0], device=device
            )
            hidden = torch.zeros(bs, hidden_size, device=device)

            loss = 0
            n_seq = op.shape[2]
            for i in range(n_seq):
                op_logit, kernel_logit, hidden = model(
                    img_in, img_out, prev_op, prev_kernel, hidden
                )
                op_loss = criterion(op_logit, op[:, :, i].float())
                kernel_loss = criterion(kernel_logit, kernel[:, i].float())
                loss += op_loss + kernel_loss

                prev_op = op[:, :, i]
                prev_kernel = kernel[:, i]

            loss /= n_seq
            loss_tot += loss.detach().cpu().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        losses['train'].append(loss_tot / len(train_loader))

        model.eval()
        loss_tot = 0
        with torch.no_grad():
            for batch in test_loader:
                bs = batch['img_in'].shape[0]

                img_in = batch['img_in'].to(device)
                img_out = batch['img_out'].to(device)
                op = batch['operation'].to(device)
                kernel = batch['kernel'].to(device)

                prev_op, prev_kernel = torch.zeros_like(op[:, :, 0],
                                                        device=device), torch.zeros_like(
                                                            kernel[:, 0], device=device
                                                        )
                hidden = torch.zeros(bs, hidden_size, device=device)

                loss = 0
                n_seq = op.shape[2]
                for i in range(n_seq):
                    op_logit, kernel_logit, hidden = model(
                        img_in, img_out, prev_op, prev_kernel, hidden
                    )
                    op_loss = criterion(op_logit, op[:, :, i].float())
                    kernel_loss = criterion(kernel_logit, kernel[:, i].float())
                    loss += op_loss + kernel_loss

                    prev_op = op[:, :, i]
                    prev_kernel = kernel[:, i]

                loss /= n_seq
                loss_tot += loss.detach().cpu().item()
        losses['test'].append(loss_tot / len(test_loader))

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {losses["train"][-1]:.4f}, Test Loss: {losses["test"][-1]:.4f}')

    return losses

### K-Fold Cross Validation


In [7]:
df = pd.read_json(file)
df.head(1)
df = df.sample(frac=1)  # shuffle


In [12]:
def k_fold(df, n_c, hidden_size, epochs=10, lr=0.01, k=1, idx=0):
    window = k
    test_df = df[idx * window:(idx + 1) * window]
    train_df = pd.concat([df[:idx * window], df[(idx + 1) * window:]])

    train_loader = DataLoader(IPARC(train_df), batch_size=128, shuffle=True, collate_fn=IPARC.collate)
    test_loader = DataLoader(IPARC(test_df), batch_size=128, shuffle=False, collate_fn=IPARC.collate)

    model = Model(n_c, hidden_size)
    losses = train(model, train_loader, test_loader, epochs, lr)

    return losses

In [13]:
n_epochs = 100
k = 4

loss_dict = {'train': np.zeros(n_epochs), 'test': np.zeros(n_epochs)}

for i in (pbar := tqdm(range(len(df) // k))):
    losses = k_fold(df, n_c, hidden_size, epochs=n_epochs, lr=0.01, k=k, idx=i)
    loss_dict['train'] += np.array(losses['train'])
    loss_dict['test'] += np.array(losses['test'])
    pbar.set_description(
        f'Loss Train: {losses["train"][-1]:.4f}, Loss Test: {losses["test"][-1]:.4f}'
    )

loss_dict['train'] /= len(df) // k
loss_dict['test'] /= len(df) // k

plt.plot(loss_dict['train'], label='train')
plt.plot(loss_dict['test'], label='test')
plt.legend()
plt.show()

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 1/100, Train Loss: 1.6670, Test Loss: 1.6438
Epoch 2/100, Train Loss: 1.3669, Test Loss: 1.4184
Epoch 3/100, Train Loss: 1.2614, Test Loss: 1.3389
Epoch 4/100, Train Loss: 1.2179, Test Loss: 1.1891
Epoch 5/100, Train Loss: 1.1811, Test Loss: 1.1252
Epoch 6/100, Train Loss: 1.1503, Test Loss: 1.2056
Epoch 7/100, Train Loss: 1.1260, Test Loss: 1.0856
Epoch 8/100, Train Loss: 1.1119, Test Loss: 1.0603
Epoch 9/100, Train Loss: 1.1042, Test Loss: 1.0950
Epoch 10/100, Train Loss: 1.0939, Test Loss: 0.9801
Epoch 11/100, Train Loss: 1.0813, Test Loss: 1.1468
Epoch 12/100, Train Loss: 1.0755, Test Loss: 0.9938
Epoch 13/100, Train Loss: 1.0688, Test Loss: 1.1292
Epoch 14/100, Train Loss: 1.0630, Test Loss: 1.0333
Epoch 15/100, Train Loss: 1.0586, Test Loss: 1.0122
Epoch 16/100, Train Loss: 1.0538, Test Loss: 0.9938
Epoch 17/100, Train Loss: 1.0441, Test Loss: 1.0562
Epoch 18/100, Train Loss: 1.0387, Test Loss: 1.2119
Epoch 19/100, Train Loss: 1.0377, Test Loss: 1.1365
Epoch 20/100, Train L

Loss Train: 0.8856, Loss Test: 0.8249:   0%|          | 1/2500 [01:53<78:45:42, 113.46s/it]

Epoch 100/100, Train Loss: 0.8856, Test Loss: 0.8249
Epoch 1/100, Train Loss: 1.6925, Test Loss: 1.5769
Epoch 2/100, Train Loss: 1.3730, Test Loss: 1.2843
Epoch 3/100, Train Loss: 1.2593, Test Loss: 1.2844
Epoch 4/100, Train Loss: 1.2095, Test Loss: 1.1262
Epoch 5/100, Train Loss: 1.1656, Test Loss: 1.1195
Epoch 6/100, Train Loss: 1.1344, Test Loss: 1.1487
Epoch 7/100, Train Loss: 1.1093, Test Loss: 1.1426
Epoch 8/100, Train Loss: 1.0913, Test Loss: 1.1432
Epoch 9/100, Train Loss: 1.0818, Test Loss: 1.2013
Epoch 10/100, Train Loss: 1.0657, Test Loss: 1.1058
Epoch 11/100, Train Loss: 1.0527, Test Loss: 1.1793
Epoch 12/100, Train Loss: 1.0490, Test Loss: 1.0476
Epoch 13/100, Train Loss: 1.0391, Test Loss: 1.1158
Epoch 14/100, Train Loss: 1.0362, Test Loss: 1.1070
Epoch 15/100, Train Loss: 1.0261, Test Loss: 0.9874
Epoch 16/100, Train Loss: 1.0224, Test Loss: 1.0473
Epoch 17/100, Train Loss: 1.0186, Test Loss: 1.1298


Loss Train: 0.8856, Loss Test: 0.8249:   0%|          | 1/2500 [02:13<92:50:40, 133.75s/it]


KeyboardInterrupt: 