In [1]:
from experiments.chordmixer import ChordMixerNet
from experiments.training_utils import count_params, seed_everything, init_weights, train_epoch, eval_model
from experiments.dataloader_utils import DatasetCreator, concater_collate
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import pandas as pd
import torch
from torch import nn, optim

# Genbank Example

In [2]:
data_train = pd.read_pickle(f'experiments/data/Carassius_Labeo_train.pkl')
data_test = pd.read_pickle(f'experiments/data/Carassius_Labeo_test.pkl')
max_seq_len = max(max(data_train['len']), max(data_test['len']))

In [3]:
data_train.head()

Unnamed: 0,sequence,label,len,bin
130668,"[1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 3, 0, 0, 3, 0, ...",0,600,3
46015,"[0, 3, 0, 3, 2, 3, 0, 3, 3, 0, 3, 1, 0, 1, 1, ...",0,1056,7
2648,"[1, 2, 2, 0, 1, 2, 1, 2, 3, 2, 2, 1, 1, 0, 3, ...",0,1412,9
8205,"[2, 0, 2, 3, 2, 1, 1, 1, 1, 2, 3, 2, 3, 2, 4, ...",1,447,2
122869,"[0, 3, 2, 1, 3, 0, 0, 0, 0, 2, 3, 0, 1, 3, 3, ...",0,1383,9


In [4]:
torch.cuda.set_device(0)
device = 'cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu'
print('set up device:', device)

net = ChordMixerNet(
    problem='genbank',
    vocab_size=20,
    max_seq_len=max_seq_len,
    embedding_size=128,
    track_size=16,
    hidden_size=128,
    mlp_dropout=0.0,
    layer_dropout=0.0,
    n_class=2
)
net = net.to(device)
net.apply(init_weights)


class_weights = compute_class_weight('balanced', classes=[0, 1], y=data_train['label'])
print('class weights:', class_weights)

loss = nn.CrossEntropyLoss(
    weight=torch.tensor(class_weights, dtype=torch.float).to(device),
    reduction='mean'
)

optimizer = optim.Adam(
    net.parameters(),
    lr=0.0001
)

set up device: cuda:0
class weights: [0.96255649 1.04047453]


In [5]:
trainset = DatasetCreator(
    df=data_train,
    batch_size=2,
    var_len=True
)

trainloader = DataLoader(
    trainset,
    batch_size=2,
    shuffle=False,
    collate_fn=concater_collate,
    drop_last=False,
    num_workers=4
)


# Prepare the testing loader
testset = DatasetCreator(
    df=data_test,
    batch_size=2,
    var_len=True
)

testloader = DataLoader(
    testset,
    batch_size=2,
    shuffle=False,
    collate_fn=concater_collate,
    drop_last=False,
    num_workers=4
)


In [6]:
def train_epoch(config, net, optimizer, loss, trainloader, device, log_every=50, scheduler=None, problem=None):
    net.train()
    running_loss = 0.0
    n_items_processed = 0
    num_batches = len(trainloader)
    for idx, (X, Y, length, bin) in tqdm(enumerate(trainloader), total=num_batches):
        if problem == 'adding':
            X = X.float().to(device)
            Y = Y.float().to(device)
            output = net(X, length).squeeze()
            output = loss(output, Y)
        else:
            X = X.to(device)
            Y = Y.to(device)            
            output = net(X, length)
            output = loss(output, Y)

        output.backward()
        optimizer.step()
        optimizer.zero_grad()
        if scheduler:
            scheduler.step()

        running_loss += output.item()
        n_items_processed += len(length)

    total_loss = running_loss / num_batches
    print(f'Training loss after epoch: {total_loss}')

def eval_model(config, net, valloader, metric, device, problem) -> float:
    net.eval()

    preds = []
    targets = []
    bins = []

    num_batches = len(valloader)
    for idx, (X, Y, length, bin) in tqdm(enumerate(valloader), total=num_batches, position=0, leave=True, ascii=False):
        if problem == 'adding':
            X = X.float().to(device)
            Y = Y.float().to(device)
            output = net(X, length).squeeze()
            predicted = output
        else:
            X = X.to(device)
            Y = Y.to(device)  
            output = net(X, length) 
            _, predicted = output.max(1)
        
        targets.extend(Y.detach().cpu().numpy().flatten())
        preds.extend(predicted.detach().cpu().numpy().flatten())
        bins.extend(bin)

    total_metric = metric(preds, targets)
    return total_metric


In [7]:
for epoch in range(15):
    print(f'Starting epoch {epoch+1}')
    config=None
    train_epoch(config, net, optimizer, loss, trainloader, device=device, log_every=10000, problem='genbank')
    test_roc_auc = eval_model(config, net, testloader, metric=roc_auc_score, device=device, problem='genbank') 
    print(f'Epoch {epoch+1} completed. Test ROCAUC: {test_roc_auc}')

Starting epoch 1


100%|██████████| 4458/4458 [01:23<00:00, 53.36it/s]

Training loss after epoch: 0.5930328630110909



100%|██████████| 1395/1395 [00:11<00:00, 120.57it/s]

Epoch 1 completed. Test ROCAUC: 0.8715695476814205
Starting epoch 2



100%|██████████| 4458/4458 [01:20<00:00, 55.33it/s]

Training loss after epoch: 0.31280290801909577



100%|██████████| 1395/1395 [00:11<00:00, 119.64it/s]

Epoch 2 completed. Test ROCAUC: 0.902502478755048
Starting epoch 3



100%|██████████| 4458/4458 [01:21<00:00, 54.73it/s]

Training loss after epoch: 0.2229982015722905



100%|██████████| 1395/1395 [00:14<00:00, 99.49it/s] 

Epoch 3 completed. Test ROCAUC: 0.9140791183062801
Starting epoch 4



100%|██████████| 4458/4458 [01:22<00:00, 54.24it/s]

Training loss after epoch: 0.16965537135811298



100%|██████████| 1395/1395 [00:14<00:00, 99.12it/s] 

Epoch 4 completed. Test ROCAUC: 0.9153168445214707
Starting epoch 5



100%|██████████| 4458/4458 [01:21<00:00, 54.63it/s]

Training loss after epoch: 0.11841853988325589



100%|██████████| 1395/1395 [00:14<00:00, 98.38it/s] 

Epoch 5 completed. Test ROCAUC: 0.9261526615547487
Starting epoch 6



100%|██████████| 4458/4458 [01:21<00:00, 54.75it/s]

Training loss after epoch: 0.07616942251129151



100%|██████████| 1395/1395 [00:14<00:00, 98.86it/s] 

Epoch 6 completed. Test ROCAUC: 0.9241066430607808
Starting epoch 7



100%|██████████| 4458/4458 [01:23<00:00, 53.50it/s]

Training loss after epoch: 0.05381972214043256



100%|██████████| 1395/1395 [00:10<00:00, 129.16it/s]

Epoch 7 completed. Test ROCAUC: 0.9215720136437474
Starting epoch 8



100%|██████████| 4458/4458 [01:21<00:00, 54.80it/s]

Training loss after epoch: 0.0501274761766566



100%|██████████| 1395/1395 [00:13<00:00, 99.77it/s] 

Epoch 8 completed. Test ROCAUC: 0.923776822560785
Starting epoch 9



100%|██████████| 4458/4458 [01:20<00:00, 55.12it/s]

Training loss after epoch: 0.03865994253573443



100%|██████████| 1395/1395 [00:13<00:00, 101.55it/s]

Epoch 9 completed. Test ROCAUC: 0.9302007666756696
Starting epoch 10



100%|██████████| 4458/4458 [01:20<00:00, 55.22it/s]

Training loss after epoch: 0.02962015546689079



100%|██████████| 1395/1395 [00:11<00:00, 118.37it/s]

Epoch 10 completed. Test ROCAUC: 0.9301623595598033
Starting epoch 11



100%|██████████| 4458/4458 [01:20<00:00, 55.08it/s]

Training loss after epoch: 0.02706635002271226



100%|██████████| 1395/1395 [00:11<00:00, 121.74it/s]

Epoch 11 completed. Test ROCAUC: 0.9305088746812644
Starting epoch 12



100%|██████████| 4458/4458 [01:21<00:00, 54.47it/s]

Training loss after epoch: 0.023921714036276322



100%|██████████| 1395/1395 [00:10<00:00, 128.84it/s]

Epoch 12 completed. Test ROCAUC: 0.9218935134056078
Starting epoch 13



100%|██████████| 4458/4458 [01:22<00:00, 54.27it/s]

Training loss after epoch: 0.025380650376584



100%|██████████| 1395/1395 [00:11<00:00, 123.03it/s]


Epoch 13 completed. Test ROCAUC: 0.9215460896538118
Starting epoch 14


100%|██████████| 4458/4458 [01:20<00:00, 55.09it/s]

Training loss after epoch: 0.013640293765500106



100%|██████████| 1395/1395 [00:10<00:00, 129.78it/s]

Epoch 14 completed. Test ROCAUC: 0.9316046498177037
Starting epoch 15



100%|██████████| 4458/4458 [01:22<00:00, 54.17it/s]

Training loss after epoch: 0.018828823426635284



100%|██████████| 1395/1395 [00:12<00:00, 108.12it/s]

Epoch 15 completed. Test ROCAUC: 0.9339238501446672





# Adding Example

In [4]:
data_train = pd.read_pickle(f'experiments/data/adding_200_train.pkl')
data_test = pd.read_pickle(f'experiments/data/adding_200_test.pkl')
max_seq_len = max(max(data_train['len']), max(data_test['len']))

In [5]:
data_train.head()

Unnamed: 0,sequence,label,len,bin
13414,"[[-0.08982194422382483, 0.0], [-0.841596420741...",0.241234,476,6
56779,"[[-0.6032521601163512, 0.0], [-0.3534035945595...",0.557188,296,4
11976,"[[0.6875824449807342, 0.0], [0.732433410329577...",0.48989,205,2
29866,"[[-0.11423020701527808, 0.0], [-0.743413113459...",0.169627,178,1
1418,"[[-0.9588989728128303, 0.0], [-0.0150490739209...",0.605129,186,2


In [6]:
torch.cuda.set_device(1)
device = 'cuda:{}'.format(1) if torch.cuda.is_available() else 'cpu'
print('set up device:', device)

net = ChordMixerNet(
    problem='adding',
    vocab_size=1,
    max_seq_len=max_seq_len,
    embedding_size=196,
    track_size=16,
    hidden_size=128,
    mlp_dropout=0.0,
    layer_dropout=0.0,
    n_class=1
)
net = net.to(device)
net.apply(init_weights)

loss = nn.MSELoss()

optimizer = optim.Adam(
    net.parameters(),
    lr=0.0003
)

set up device: cuda:1


In [7]:
trainset = DatasetCreator(
    df=data_train,
    batch_size=20,
    var_len=True
)

trainloader = DataLoader(
    trainset,
    batch_size=20,
    shuffle=False,
    collate_fn=concater_collate,
    drop_last=False,
    num_workers=4
)


# Prepare the testing loader
testset = DatasetCreator(
    df=data_test,
    batch_size=20,
    var_len=True
)

testloader = DataLoader(
    testset,
    batch_size=20,
    shuffle=False,
    collate_fn=concater_collate,
    drop_last=False,
    num_workers=4
)


In [8]:
import numpy as np
def accuracy_adding(predictions, target):
    predictions = np.array(predictions)
    target = np.array(target)
    score = (np.abs(predictions - target) < 0.04).mean()
    return score

for epoch in range(20):
    print(f'Starting epoch {epoch+1}')
    config=None
    train_epoch(config, net, optimizer, loss, trainloader, device=device, log_every=100000, problem='adding')
    accuracy = eval_model(config, net, testloader, metric=accuracy_adding, device=device, problem='adding') 
    print(f'Epoch {epoch+1} completed. Test accuracy: {accuracy}')

Starting epoch 1


100%|██████████| 2493/2493 [03:26<00:00, 12.04it/s]

Training loss after epoch: 0.027713460761252782



100%|██████████| 256/256 [00:11<00:00, 23.27it/s]

Epoch 1 completed. Test accuracy: 0.2978515625
Starting epoch 2



100%|██████████| 2493/2493 [03:29<00:00, 11.93it/s]

Training loss after epoch: 0.009715184406176578



100%|██████████| 256/256 [00:13<00:00, 18.54it/s]

Epoch 2 completed. Test accuracy: 0.39140625
Starting epoch 3



100%|██████████| 2493/2493 [03:30<00:00, 11.82it/s]

Training loss after epoch: 0.005981061408553807



100%|██████████| 256/256 [00:10<00:00, 23.32it/s]

Epoch 3 completed. Test accuracy: 0.610546875
Starting epoch 4



100%|██████████| 2493/2493 [03:26<00:00, 12.10it/s]

Training loss after epoch: 0.002842787358276122



100%|██████████| 256/256 [00:10<00:00, 23.81it/s]

Epoch 4 completed. Test accuracy: 0.730078125
Starting epoch 5



100%|██████████| 2493/2493 [03:29<00:00, 11.88it/s]

Training loss after epoch: 0.0016909314767512926



100%|██████████| 256/256 [00:14<00:00, 17.67it/s]

Epoch 5 completed. Test accuracy: 0.879296875
Starting epoch 6



100%|██████████| 2493/2493 [03:31<00:00, 11.77it/s]

Training loss after epoch: 0.001128642675652514



100%|██████████| 256/256 [00:10<00:00, 23.87it/s]

Epoch 6 completed. Test accuracy: 0.8083984375
Starting epoch 7



100%|██████████| 2493/2493 [03:27<00:00, 12.04it/s]

Training loss after epoch: 0.0012362024246399361



100%|██████████| 256/256 [00:10<00:00, 23.90it/s]

Epoch 7 completed. Test accuracy: 0.9296875
Starting epoch 8



100%|██████████| 2493/2493 [03:27<00:00, 12.02it/s]

Training loss after epoch: 0.0006878992494961187



100%|██████████| 256/256 [00:11<00:00, 22.49it/s]

Epoch 8 completed. Test accuracy: 0.9330078125
Starting epoch 9



100%|██████████| 2493/2493 [03:27<00:00, 12.00it/s]

Training loss after epoch: 0.0006834379984630572



100%|██████████| 256/256 [00:10<00:00, 23.58it/s]

Epoch 9 completed. Test accuracy: 0.908984375
Starting epoch 10



100%|██████████| 2493/2493 [03:25<00:00, 12.12it/s]

Training loss after epoch: 0.0005689226818744593



100%|██████████| 256/256 [00:10<00:00, 23.73it/s]

Epoch 10 completed. Test accuracy: 0.9216796875
Starting epoch 11



100%|██████████| 2493/2493 [03:25<00:00, 12.16it/s]

Training loss after epoch: 0.0003845247808538325



100%|██████████| 256/256 [00:14<00:00, 17.67it/s]

Epoch 11 completed. Test accuracy: 0.94609375
Starting epoch 12



100%|██████████| 2493/2493 [03:31<00:00, 11.77it/s]

Training loss after epoch: 0.0004038411665346416



100%|██████████| 256/256 [00:11<00:00, 21.76it/s]

Epoch 12 completed. Test accuracy: 0.976171875
Starting epoch 13



100%|██████████| 2493/2493 [03:27<00:00, 12.04it/s]

Training loss after epoch: 0.0003447734195277123



100%|██████████| 256/256 [00:14<00:00, 17.69it/s]

Epoch 13 completed. Test accuracy: 0.98203125
Starting epoch 14



100%|██████████| 2493/2493 [03:27<00:00, 12.00it/s]

Training loss after epoch: 0.00026726826554729393



100%|██████████| 256/256 [00:10<00:00, 24.05it/s]

Epoch 14 completed. Test accuracy: 0.9796875
Starting epoch 15



100%|██████████| 2493/2493 [03:32<00:00, 11.72it/s]

Training loss after epoch: 0.00031323987477451144



100%|██████████| 256/256 [00:10<00:00, 23.69it/s]

Epoch 15 completed. Test accuracy: 0.994140625
Starting epoch 16



100%|██████████| 2493/2493 [03:25<00:00, 12.14it/s]

Training loss after epoch: 0.00023953229634787445



100%|██████████| 256/256 [00:14<00:00, 18.23it/s]

Epoch 16 completed. Test accuracy: 0.9953125
Starting epoch 17



100%|██████████| 2493/2493 [03:24<00:00, 12.21it/s]

Training loss after epoch: 0.00020064286104835277



100%|██████████| 256/256 [00:10<00:00, 23.81it/s]

Epoch 17 completed. Test accuracy: 0.9970703125
Starting epoch 18



100%|██████████| 2493/2493 [03:26<00:00, 12.09it/s]

Training loss after epoch: 0.0002564679064256286



100%|██████████| 256/256 [00:11<00:00, 23.03it/s]

Epoch 18 completed. Test accuracy: 0.9861328125
Starting epoch 19



100%|██████████| 2493/2493 [03:30<00:00, 11.84it/s]

Training loss after epoch: 0.00015979634654238775



100%|██████████| 256/256 [00:10<00:00, 23.86it/s]

Epoch 19 completed. Test accuracy: 0.998046875
Starting epoch 20



100%|██████████| 2493/2493 [03:24<00:00, 12.18it/s]

Training loss after epoch: 0.0001929583848571533



100%|██████████| 256/256 [00:14<00:00, 17.53it/s]

Epoch 20 completed. Test accuracy: 0.997265625



