In [1]:
import pandas as pd
import numpy as np

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchmetrics import Accuracy


from matplotlib import pyplot as plt
from tqdm import tqdm

In [2]:
# n_c is embedded in the dataset file name and is also passed to the model
# constructor (presumably the number of retained components - confirm).
n_c = 60
file = './Dataset/dataset5k_reduced_{}.json'.format(n_c)

### Datasets


In [3]:
# dataset and batch collation
class IPARC(Dataset):
    """Dataset view over a DataFrame, yielding one dict of tensors per row.

    Each row must provide the columns ``input_reduced``, ``output_reduced``,
    ``operation`` and ``kernel`` holding values accepted by ``torch.tensor``.
    """

    # Keys present in every sample, in the order they appear in a batch dict.
    _FIELDS = ('img_in', 'img_out', 'operation', 'kernel')

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample at positional index ``idx`` as a dict of tensors."""
        record = self.df.iloc[idx]

        # Both images are flattened to 1-D.
        flat_in = torch.tensor(record['input_reduced']).reshape(-1)
        flat_out = torch.tensor(record['output_reduced']).reshape(-1)
        # The operation sequence becomes a single-row 2-D tensor of shape (1, n).
        op_row = torch.tensor(record['operation']).reshape(1, -1)
        kernel = torch.tensor(record['kernel'])

        return dict(zip(self._FIELDS, (flat_in, flat_out, op_row, kernel)))

    @staticmethod
    def collate(batch):
        """Stack a list of samples into one dict of batched tensors (new dim 0)."""
        return {
            field: torch.stack([sample[field] for sample in batch])
            for field in IPARC._FIELDS
        }

### Model


In [4]:
class Model(nn.Module):
    """Two linear heads over the concatenated images and previous step.

    One head emits a single operation logit, the other emits 8 kernel logits.
    Both read the same feature vector
    ``cat([img_in, img_out, op_prev, kernel_prev], dim=-1)``.

    Args:
        n_c: Size parameter of the reduced images; the layers expect
            ``4 * (n_c + n_c) + 1 + 8`` input features in total.
    """

    def __init__(self, n_c):
        super().__init__()
        # Width of the concatenated feature vector: presumably 4 * n_c values
        # per flattened image, plus 1 previous-operation value and an 8-wide
        # previous-kernel vector.
        in_features = 4 * (n_c + n_c) + 1 + 8
        self.op_linreg = nn.Linear(in_features, 1)
        self.kernel_linreg = nn.Linear(in_features, 8)

    def forward(self, img_in, img_out, op_prev, kernel_prev):
        """Return ``(op_logit, kernel_logit)`` for one step.

        Args:
            img_in, img_out: Flattened input/output images, last dim concatenated.
            op_prev: Previous operation value(s), last dim concatenated.
            kernel_prev: Previous kernel vector, last dim concatenated.

        Returns:
            ``op_logit`` with the trailing size-1 dim squeezed away, and
            ``kernel_logit`` with 8 values along the last dim.
        """
        # The op/kernel tensors may arrive as integer tensors while the images
        # are floats (or integers too). Cast everything to the layers' dtype
        # explicitly instead of relying on implicit promotion inside torch.cat,
        # which leaves an all-integer input as Long and breaks nn.Linear.
        dtype = self.op_linreg.weight.dtype
        features = torch.cat(
            [part.to(dtype) for part in (img_in, img_out, op_prev, kernel_prev)],
            dim=-1,
        )

        op_logit = self.op_linreg(features).squeeze(-1)
        kernel_logit = self.kernel_linreg(features)

        return op_logit, kernel_logit

### Training Loop


In [5]:
# training loop
def _run_epoch(model, loader, criterion_bce, criterion_ce, acc_bin, acc_multi, device, optimizer=None):
    """Run one full pass over ``loader``; update the model if ``optimizer`` is given.

    Args:
        model: Callable as ``model(img_in, img_out, prev_op, prev_kernel)``
            returning ``(op_logit, kernel_logit)``.
        loader: Iterable of batch dicts with keys ``img_in``, ``img_out``,
            ``operation`` and ``kernel``.
        criterion_bce: Loss applied to the operation logit.
        criterion_ce: Loss applied to the kernel logits.
        acc_bin, acc_multi: Metric objects for the operation / kernel heads.
        device: Device the batch tensors are moved to.
        optimizer: When not ``None`` gradients are enabled and one optimizer
            step is taken per batch; otherwise the pass runs without gradients.

    Returns:
        ``(loss, op_acc, kernel_acc)`` where each value is the unweighted mean
        over batches of the per-batch, per-step average.
    """
    training = optimizer is not None
    loss_tot, op_acc_tot, kernel_acc_tot = 0.0, 0.0, 0.0

    with torch.set_grad_enabled(training):
        for batch in loader:
            # Only the value returned by each metric call is used, so the
            # accumulated metric state is cleared for every batch.
            acc_bin.reset()
            acc_multi.reset()

            img_in = batch['img_in'].to(device)
            img_out = batch['img_out'].to(device)
            op = batch['operation'].to(device)
            kernel = batch['kernel'].to(device)

            # The first step is conditioned on an all-zero "previous" op/kernel.
            prev_op = torch.zeros_like(op[:, :, 0], device=device)
            prev_kernel = torch.zeros_like(kernel[:, 0], device=device)

            loss, op_acc, kernel_acc = 0, 0, 0
            n_seq = op.shape[2]
            for i in range(n_seq):
                op_logit, kernel_logit = model(img_in, img_out, prev_op, prev_kernel)

                op_target = op[:, 0, i].float()
                # Column index of the entry equal to 1 in each row of step i
                # (assumes exactly one such entry per row - TODO confirm).
                kernel_target = (kernel[:, i] == 1).nonzero(as_tuple=True)[1].long()

                loss += criterion_bce(op_logit, op_target) + criterion_ce(kernel_logit, kernel_target)
                op_acc += acc_bin(op_logit, op_target)
                kernel_acc += acc_multi(kernel_logit, kernel_target)

                # Feed the ground-truth step i back in as the next "previous" input.
                prev_op = op[:, :, i]
                prev_kernel = kernel[:, i]

            loss /= n_seq
            op_acc /= n_seq
            kernel_acc /= n_seq

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_tot += loss.detach().cpu().item()
            op_acc_tot += op_acc.detach().cpu().item()
            kernel_acc_tot += kernel_acc.detach().cpu().item()

    n_batches = len(loader)
    return loss_tot / n_batches, op_acc_tot / n_batches, kernel_acc_tot / n_batches


def train(model, train_loader, test_loader, epochs=10, lr=0.001, device='cpu'):
    """Train ``model`` with Adam and evaluate it on ``test_loader`` after every epoch.

    Args:
        model: Module to optimise; it is moved to ``device``.
        train_loader: Batches used for optimisation.
        test_loader: Batches used for evaluation (no gradient, ``model.eval()``).
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device for the model, metrics and batches.

    Returns:
        ``(losses, metrics)``. ``losses`` maps ``'train'``/``'test'`` to a list
        with one mean loss per epoch; ``metrics`` maps ``'train'``/``'test'`` to
        ``{'operation': [...], 'kernel': [...]}`` with one mean accuracy per epoch.
    """
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion_bce = nn.BCEWithLogitsLoss()
    criterion_ce = nn.CrossEntropyLoss()
    acc_bin = Accuracy(task='binary').to(device)
    acc_multi = Accuracy(task='multiclass', num_classes=8).to(device)

    losses = {'train': [], 'test': []}
    metrics = {'train': {'operation': [], 'kernel': []}, 'test': {'operation': [], 'kernel': []}}

    for epoch in (pbar := tqdm(range(epochs))):
        model.train()
        tr_loss, tr_op, tr_ker = _run_epoch(
            model, train_loader, criterion_bce, criterion_ce, acc_bin, acc_multi, device,
            optimizer=optimizer,
        )
        losses['train'].append(tr_loss)
        metrics['train']['operation'].append(tr_op)
        metrics['train']['kernel'].append(tr_ker)

        model.eval()
        te_loss, te_op, te_ker = _run_epoch(
            model, test_loader, criterion_bce, criterion_ce, acc_bin, acc_multi, device,
        )
        losses['test'].append(te_loss)
        metrics['test']['operation'].append(te_op)
        metrics['test']['kernel'].append(te_ker)

        pbar.set_description(f'{epoch + 1} | tr-loss: {losses["train"][-1]:.4f} | tr-op: {metrics["train"]["operation"][-1]:.4f} | te-op: {metrics["test"]["operation"][-1]:.4f} | tr-ker: {metrics["train"]["kernel"][-1]:.4f} | te-ker: {metrics["test"]["kernel"][-1]:.4f}')

    return losses, metrics

### K-Fold Cross Validation


In [6]:
# Shuffle the rows once so the contiguous slices taken later for each fold are
# random subsets. The seed is fixed so fold membership is the same on every
# run (Restart & Run All reproduces the same splits).
SEED = 0
df = pd.read_json(file).sample(frac=1, random_state=SEED)

In [7]:
def k_fold(df, n_c, epochs=10, lr=0.01, k=1, idx=0, device='cpu'):
    """Train and evaluate a fresh model on one cross-validation fold.

    Rows ``[idx * k, (idx + 1) * k)`` of ``df`` are held out for testing and
    all remaining rows are used for training. Note that ``k`` is the number of
    rows per fold, not the number of folds.

    Args:
        df: Full dataset; sliced as-is, without shuffling here.
        n_c: Size parameter forwarded to ``Model``.
        epochs: Number of training epochs.
        lr: Learning rate forwarded to ``train``.
        k: Number of rows in the held-out slice.
        idx: Index of the fold to hold out.
        device: Device forwarded to ``train``.

    Returns:
        The ``(losses, metrics)`` pair returned by ``train``.
    """
    start = idx * k
    stop = (idx + 1) * k

    held_out = df[start:stop]
    remaining = pd.concat([df[:start], df[stop:]])

    train_loader = DataLoader(
        IPARC(remaining), batch_size=16, shuffle=True, collate_fn=IPARC.collate
    )
    test_loader = DataLoader(
        IPARC(held_out), batch_size=16, shuffle=False, collate_fn=IPARC.collate
    )

    # A new model per fold, so no weights carry over between folds.
    fold_model = Model(n_c)
    return train(fold_model, train_loader, test_loader, epochs=epochs, lr=lr, device=device)

In [8]:
n_epochs = 40
k = 500  # number of rows held out per fold (passed to k_fold as the slice size)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

n_folds = len(df) // k
if n_folds == 0:
    raise ValueError(f'need at least k={k} rows to build one fold, got {len(df)}')

_SPLITS = ('train', 'test')
_HEADS = ('operation', 'kernel')

# Per-epoch curves, first summed over folds and then averaged below.
loss_dict = {split: np.zeros(n_epochs) for split in _SPLITS}
metrics_dict = {split: {head: np.zeros(n_epochs) for head in _HEADS} for split in _SPLITS}

for i in range(n_folds):
    losses, metrics = k_fold(df, n_c, epochs=n_epochs, lr=0.01, k=k, idx=i, device=device)
    for split in _SPLITS:
        loss_dict[split] += np.array(losses[split])
        for head in _HEADS:
            metrics_dict[split][head] += np.array(metrics[split][head])

# Turn the sums into means over folds. The accuracies must be averaged as well
# as the losses, otherwise the reported accuracy is a sum over folds.
for split in _SPLITS:
    loss_dict[split] /= n_folds
    for head in _HEADS:
        metrics_dict[split][head] /= n_folds


def _report_curves(train_curve, test_curve, name, title):
    """Print the final-epoch train/test values and plot both curves over epochs."""
    print(f'train {name}', train_curve[-1])
    print(f'test {name}', test_curve[-1])
    fig, ax = plt.subplots()
    ax.plot(train_curve, label='train')
    ax.plot(test_curve, label='test')
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    plt.show()


_report_curves(loss_dict['train'], loss_dict['test'], 'loss', 'Loss')
_report_curves(
    metrics_dict['train']['operation'], metrics_dict['test']['operation'],
    'operation accuracy', 'Operation Accuracy',
)
_report_curves(
    metrics_dict['train']['kernel'], metrics_dict['test']['kernel'],
    'kernel accuracy', 'Kernel Accuracy',
)

1 | tr-loss: 2.4016 | tr-op: 0.8029 | te-op: 0.8730 | tr-ker: 0.2752 | te-ker: 0.2988:   2%|▎         | 1/40 [00:23<15:09, 23.32s/it]


KeyboardInterrupt: 