In [1]:
%matplotlib inline

import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.model_selection import train_test_split
from itertools import product
import collections
import math
import random
import matplotlib.pyplot as plt
import numpy as np

from utils import ModelTrainer
from datasets import avGFPDataset, GB1Dataset

In [2]:
class FCN(nn.Module):
    """Four-layer fully connected regressor that outputs one scalar per row.

    Layer widths are n -> multiplier*n -> multiplier*n -> n -> 1. Each hidden
    layer is followed by LeakyReLU and, when ``batch_norm`` is set, by a
    BatchNorm1d applied to the activation. All Linear weights are
    re-initialised with Xavier-uniform; biases keep PyTorch's default init.

    Args:
        n: number of input features.
        multiplier: width factor for the two inner hidden layers.
        batch_norm: if True, add BatchNorm1d after each hidden activation.
    """

    def __init__(self, n, multiplier=2, batch_norm=False):
        super().__init__()
        wide = multiplier * n

        # Layers are created in the same order as before (fc1..fc4, then the
        # Xavier pass), so a given torch seed yields the same initial weights.
        self.fc1 = nn.Linear(n, wide)
        self.fc2 = nn.Linear(wide, wide)
        self.fc3 = nn.Linear(wide, n)
        self.fc4 = nn.Linear(n, 1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            torch.nn.init.xavier_uniform_(layer.weight)

        self.batch_norm = batch_norm
        if batch_norm:
            self.bn1 = nn.BatchNorm1d(wide)
            self.bn2 = nn.BatchNorm1d(wide)
            self.bn3 = nn.BatchNorm1d(n)

    def forward(self, x):
        """Run the network and flatten the (rows, 1) output to shape (rows,)."""
        hidden_layers = (self.fc1, self.fc2, self.fc3)
        # bn* attributes only exist when batch_norm is enabled.
        norm_layers = (self.bn1, self.bn2, self.bn3) if self.batch_norm else (None, None, None)

        for linear, norm in zip(hidden_layers, norm_layers):
            x = F.leaky_relu(linear(x))
            if norm is not None:
                x = norm(x)

        return self.fc4(x).reshape(-1)

In [None]:
# --- Experiment configuration -------------------------------------------------
# `fix_seed` seeds the train/validation split; `random_seed` seeds network
# initialisation (see the two torch.manual_seed calls below).
fix_seed = 1
random_seed = 11

train_size = 2000

# Named presets for the training methods explored in this notebook.
# Exactly one is selected via PRESET; the others are kept for reference.
CONFIG_PRESETS = {
    "hashing": {
        "training_method": "hashing",
        "b": 10,
        "lr": 0.01,
        "weight_decay": 0,
        "hadamard_lambda": 0.0001,
        "num_epochs": 100,
        "random_seed": random_seed,
        "fix_seed": fix_seed,
        "train_size": train_size,
        "epoch_iterations": 50,
    },
    "normal": {
        "training_method": "normal",
        "b": 10,
        "lr": 0.01,
        "weight_decay": 0,
        "num_epochs": 100,
        "random_seed": random_seed,
        "fix_seed": fix_seed,
        "train_size": train_size,
        "epoch_iterations": 40,
        "dataset": "GB1",
    },
    "EN-S": {
        "training_method": "EN-S",
        "SPRIGHT_d": 3,
        "rho": 0.01,
        "b": 10,
        "lr": 0.01,
        "weight_decay": 0,
        "hadamard_lambda": 1,
        "num_epochs": 10,
        "random_seed": random_seed,
        "fix_seed": fix_seed,
        "train_size": 500,
        "epoch_iterations": 50,
    },
    "EN-S_data": {
        "training_method": "EN-S_data",
        "b": 10,
        "lr": 0.01,
        "SPRIGHT_d": 3,
        "weight_decay": 0,
        "hadamard_lambda": 0.1,
        "num_epochs": 100,
        "random_seed": random_seed,
        "fix_seed": fix_seed,
        "train_size": train_size,
        "epoch_iterations": 20,
        "run": 2,
    },
}
PRESET = "EN-S_data"
# Shallow copy so the derived "batch_size" added below never leaks into the
# preset table; re-running this cell is therefore idempotent.
config = dict(CONFIG_PRESETS[PRESET])

# --- Dataset ------------------------------------------------------------------
# Seeds torch's global RNG, which random_split below draws from.
torch.manual_seed(config["fix_seed"])
# NOTE(review): GB1Dataset() is constructed after seeding; if its constructor
# consumes torch's global RNG, the split depends on that too — confirm.
dataset = GB1Dataset()
dataset_size = len(dataset)
if not 0 < config["train_size"] < dataset_size:
    raise ValueError(
        f"train_size={config['train_size']} must be in (0, {dataset_size}) "
        "so both the train and validation splits are non-empty"
    )
train_ds, val_ds = torch.utils.data.random_split(
    dataset, [config["train_size"], dataset_size - config["train_size"]]
)

# Batch size chosen so that at most `epoch_iterations` batches cover
# `train_size` samples (the last batch may be smaller).
config["batch_size"] = math.ceil(config["train_size"] / config["epoch_iterations"])

# --- Model & training ---------------------------------------------------------
torch.manual_seed(config["random_seed"])  # seed for network initialization
in_dim = dataset.X.shape[1]  # number of columns in the dataset's X matrix
model = FCN(in_dim, multiplier=1, batch_norm=False)
trainer = ModelTrainer(model, train_ds, val_ds, config=config, plot_results=True, checkpoint_cache=True)
model = trainer.train_model()


Loaded dataset from cache.
Loading SPRIGHT samples from cache ...
#0 - Train Loss: 0.199, R2: -0.054	Validation Loss: 0.134, R2: 0.137
#1 - Train Loss: 0.133, R2: 0.295	Validation Loss: 0.101, R2: 0.351
#2 - Train Loss: 0.059, R2: 0.687	Validation Loss: 0.096, R2: 0.384
#3 - Train Loss: 0.025, R2: 0.867	Validation Loss: 0.094, R2: 0.397
#4 - Train Loss: 0.025, R2: 0.868	Validation Loss: 0.092, R2: 0.411
#5 - Train Loss: 0.026, R2: 0.864	Validation Loss: 0.089, R2: 0.427
#6 - Train Loss: 0.010, R2: 0.949	Validation Loss: 0.080, R2: 0.487
#7 - Train Loss: 0.007, R2: 0.964	Validation Loss: 0.085, R2: 0.458
#8 - Train Loss: 0.006, R2: 0.970	Validation Loss: 0.077, R2: 0.509
#9 - Train Loss: 0.005, R2: 0.976	Validation Loss: 0.083, R2: 0.467
#10 - Train Loss: 0.006, R2: 0.970	Validation Loss: 0.084, R2: 0.460
#11 - Train Loss: 0.005, R2: 0.975	Validation Loss: 0.080, R2: 0.486
#12 - Train Loss: 0.003, R2: 0.984	Validation Loss: 0.083, R2: 0.466
#13 - Train Loss: 0.003, R2: 0.984	Validation 