# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import talos

# Load the pickled train/test splits. The b'data' / b'coarse_labels' keys
# suggest the CIFAR-100 "python version" files -- TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these files from a trusted source.
with open('train', 'rb') as file:
    train_dict = pickle.load(file, encoding='bytes')

with open('test', 'rb') as file:
    test_dict = pickle.load(file, encoding='bytes')

X_train = train_dict[b'data']
y_train = train_dict[b'coarse_labels']

X_test = test_dict[b'data']
y_test = test_dict[b'coarse_labels']

# Convert to float32 and divide by 255; presumably the raw values are 8-bit
# pixel intensities, giving features in [0, 1] -- confirm against the data.
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

# One-hot encode the integer labels. one_hot() returns an int64 tensor, but
# losses that take dense targets (CrossEntropyLoss with probability targets,
# MSELoss, BCELoss, ...) require floating-point targets, so cast to float32.
# NOTE(review): num_classes=100 matches MyModel's default 100 outputs, but the
# coarse labels presumably span only 20 classes -- confirm whether the fine
# labels (or num_classes=20 with a matching output layer) were intended.
y_train = nn.functional.one_hot(torch.tensor(y_train).long(), num_classes=100).float()
y_test = nn.functional.one_hot(torch.tensor(y_test).long(), num_classes=100).float()

class MyModel(nn.Module):
    """Fully connected network: five hidden Linear layers plus an output layer.

    The model also owns its loss function and optimizer. Both are looked up by
    name so they can be chosen from a hyperparameter dictionary.

    Args:
        units: Width of every hidden layer.
        hidden_activation: Name of a ``torch.nn`` module class (e.g. ``'ReLU'``),
            instantiated with no arguments and applied after each hidden layer.
        activation: Name of a ``torch.nn`` module class, instantiated with no
            arguments and applied to the output layer.
        loss: Name of a ``torch.nn`` loss class; an instance is stored as
            ``self.loss``.
        optimizer: Name of a ``torch.optim`` optimizer class. It is constructed
            over all model parameters with its default hyperparameters and
            stored as ``self.optimizer``.
        in_features: Input width. When ``None`` (the default), it falls back to
            ``X_train.shape[1]`` from the module-level training matrix, which
            preserves the previous behavior.
        num_classes: Output width (default 100).
    """

    def __init__(self, units, hidden_activation, activation, loss, optimizer,
                 *, in_features=None, num_classes=100):
        super().__init__()
        if in_features is None:
            # Backward-compatible fallback: size the input layer from the
            # module-level training data, as the original implementation did.
            in_features = X_train.shape[1]
        # Layers keep their fc1..fc6 attribute names so existing state_dicts load.
        self.fc1 = nn.Linear(in_features=in_features, out_features=units)
        self.fc2 = nn.Linear(in_features=units, out_features=units)
        self.fc3 = nn.Linear(in_features=units, out_features=units)
        self.fc4 = nn.Linear(in_features=units, out_features=units)
        self.fc5 = nn.Linear(in_features=units, out_features=units)
        self.fc6 = nn.Linear(in_features=units, out_features=num_classes)
        # A single activation instance is reused after every hidden layer, so
        # any learnable parameters it has (e.g. PReLU) are shared by all five.
        self.hidden_activation = getattr(nn, hidden_activation)()
        self.activation = getattr(nn, activation)()
        self.loss = getattr(nn, loss)()
        # Built last so self.parameters() already includes every submodule above.
        self.optimizer = getattr(optim, optimizer)(self.parameters())

    def forward(self, x):
        """Return the output-layer activations for input ``x``.

        ``x``'s last dimension must equal the input width (``fc1.in_features``).
        """
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            x = self.hidden_activation(layer(x))
        return self.activation(self.fc6(x))

def train_model(X_train, y_train, X_val, y_val, params):
    model = MyModel(params['units'], params['hidden_activation'], params['activation'], params['loss'], params['optimizer'])
    num_epochs = 200
    batch_size = params['batch_size']
    train_dataset = torch.utils.data.TensorDataset(torch.tensor(X_train), y_train)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataset = torch.utils.data.TensorDataset(torch.tensor(X_val), y_val)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    for epoch in range(num_epochs):
        model.train()
        for batch, (data, target) in enumerate(train_loader):
            model.optimizer.zero_grad()
            output = model(data)
            loss = model.loss(output, target)
            loss.backward()
            model.optimizer.step()
        model.eval()
        with torch.no_grad():
            val_loss = 0
            correct = 0
            total = 0
            for batch, (data, target) in enumerate(val_loader):
                output = model(data)
                val_loss += model.loss(output, target).item()
                pred = torch.argmax(output, dim=1)
                correct += (pred == torch.argmax)
