In [1]:
from experiments.chordmixer import ChordMixerNet
from experiments.training_utils import count_params, seed_everything, init_weights, train_epoch, eval_model
from experiments.dataloader_utils import DatasetCreator, concater_collate
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import pandas as pd
import torch
from torch import nn, optim

# Genbank Example

In [2]:
data_train = pd.read_pickle(f'experiments/data/Carassius_Labeo_train.pkl')
data_test = pd.read_pickle(f'experiments/data/Carassius_Labeo_test.pkl')
max_seq_len = max(max(data_train['len']), max(data_test['len']))

In [3]:
data_train.head()

Unnamed: 0,sequence,label,len,bin
130668,"[1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 3, 0, 0, 3, 0, ...",0,600,3
46015,"[0, 3, 0, 3, 2, 3, 0, 3, 3, 0, 3, 1, 0, 1, 1, ...",0,1056,7
2648,"[1, 2, 2, 0, 1, 2, 1, 2, 3, 2, 2, 1, 1, 0, 3, ...",0,1412,9
8205,"[2, 0, 2, 3, 2, 1, 1, 1, 1, 2, 3, 2, 3, 2, 4, ...",1,447,2
122869,"[0, 3, 2, 1, 3, 0, 0, 0, 0, 2, 3, 0, 1, 3, 3, ...",0,1383,9


In [4]:
torch.cuda.set_device(0)
device = 'cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu'
print('set up device:', device)

net = ChordMixerNet(
    problem='genbank',
    vocab_size=20,
    max_seq_len=max_seq_len,
    embedding_size=128,
    track_size=16,
    hidden_size=128,
    mlp_dropout=0.0,
    layer_dropout=0.0,
    n_class=2
)
net = net.to(device)
net.apply(init_weights)


class_weights = compute_class_weight('balanced', classes=[0, 1], y=data_train['label'])
print('class weights:', class_weights)

loss = nn.CrossEntropyLoss(
    weight=torch.tensor(class_weights, dtype=torch.float).to(device),
    reduction='mean'
)

optimizer = optim.Adam(
    net.parameters(),
    lr=0.0001
)

set up device: cuda:0
class weights: [0.96255649 1.04047453]


In [5]:
trainset = DatasetCreator(
    df=data_train,
    batch_size=2,
    var_len=True
)

trainloader = DataLoader(
    trainset,
    batch_size=2,
    shuffle=False,
    collate_fn=concater_collate,
    drop_last=False,
    num_workers=4
)


# Prepare the testing loader
testset = DatasetCreator(
    df=data_test,
    batch_size=2,
    var_len=True
)

testloader = DataLoader(
    testset,
    batch_size=2,
    shuffle=False,
    collate_fn=concater_collate,
    drop_last=False,
    num_workers=4
)


In [6]:
def train_epoch(config, net, optimizer, loss, trainloader, device, log_every=50, scheduler=None, problem=None):
    net.train()
    running_loss = 0.0
    n_items_processed = 0
    num_batches = len(trainloader)
    for idx, (X, Y, length, bin) in tqdm(enumerate(trainloader), total=num_batches):
        if problem == 'adding':
            X = X.float().to(device)
            Y = Y.float().to(device)
            output = net(X, length).squeeze()
            output = loss(output, Y)
        else:
            X = X.to(device)
            Y = Y.to(device)            
            output = net(X, length)
            output = loss(output, Y)

        output.backward()
        optimizer.step()
        optimizer.zero_grad()
        if scheduler:
            scheduler.step()

        running_loss += output.item()
        n_items_processed += len(length)

    total_loss = running_loss / num_batches
    print(f'Training loss after epoch: {total_loss}')

def eval_model(config, net, valloader, metric, device, problem) -> float:
    net.eval()

    preds = []
    targets = []
    bins = []

    num_batches = len(valloader)
    for idx, (X, Y, length, bin) in tqdm(enumerate(valloader), total=num_batches, position=0, leave=True, ascii=False):
        if problem == 'adding':
            X = X.float().to(device)
            Y = Y.float().to(device)
            output = net(X, length).squeeze()
            predicted = output
        else:
            X = X.to(device)
            Y = Y.to(device)  
            output = net(X, length) 
            _, predicted = output.max(1)
        
        targets.extend(Y.detach().cpu().numpy().flatten())
        preds.extend(predicted.detach().cpu().numpy().flatten())
        bins.extend(bin)

    total_metric = metric(preds, targets)
    return total_metric


In [7]:
for epoch in range(15):
    print(f'Starting epoch {epoch+1}')
    config=None
    train_epoch(config, net, optimizer, loss, trainloader, device=device, log_every=10000, problem='genbank')
    test_roc_auc = eval_model(config, net, testloader, metric=roc_auc_score, device=device, problem='genbank') 
    print(f'Epoch {epoch+1} completed. Test ROCAUC: {test_roc_auc}')

Starting epoch 1


100%|██████████| 4458/4458 [01:25<00:00, 51.87it/s]

Training loss after epoch: 0.5849942228323071



100%|██████████| 1395/1395 [00:11<00:00, 120.51it/s]

Epoch 1 completed. Test ROCAUC: 0.8899681273387698
Starting epoch 2



100%|██████████| 4458/4458 [01:26<00:00, 51.44it/s]

Training loss after epoch: 0.31362329200395805



100%|██████████| 1395/1395 [00:11<00:00, 120.00it/s]

Epoch 2 completed. Test ROCAUC: 0.9135172994086123
Starting epoch 3



100%|██████████| 4458/4458 [01:25<00:00, 52.13it/s]

Training loss after epoch: 0.22537807819689776



100%|██████████| 1395/1395 [00:11<00:00, 122.66it/s]


Epoch 3 completed. Test ROCAUC: 0.9213178370218303
Starting epoch 4


100%|██████████| 4458/4458 [01:25<00:00, 52.23it/s]

Training loss after epoch: 0.17396101961957255



100%|██████████| 1395/1395 [00:11<00:00, 123.78it/s]

Epoch 4 completed. Test ROCAUC: 0.9200814221325141
Starting epoch 5



100%|██████████| 4458/4458 [01:25<00:00, 51.98it/s]

Training loss after epoch: 0.12504067431678922



100%|██████████| 1395/1395 [00:11<00:00, 119.88it/s]

Epoch 5 completed. Test ROCAUC: 0.9210656989365524
Starting epoch 6



100%|██████████| 4458/4458 [01:24<00:00, 52.62it/s]


Training loss after epoch: 0.08629379985929134


100%|██████████| 1395/1395 [00:13<00:00, 106.32it/s]

Epoch 6 completed. Test ROCAUC: 0.9244940437710598
Starting epoch 7



100%|██████████| 4458/4458 [01:24<00:00, 52.49it/s]

Training loss after epoch: 0.06143969555251655



100%|██████████| 1395/1395 [00:13<00:00, 106.69it/s]

Epoch 7 completed. Test ROCAUC: 0.9260182188482033
Starting epoch 8



100%|██████████| 4458/4458 [01:31<00:00, 48.56it/s]

Training loss after epoch: 0.05725323206387607



100%|██████████| 1395/1395 [00:12<00:00, 110.14it/s]

Epoch 8 completed. Test ROCAUC: 0.9140568447218541
Starting epoch 9



100%|██████████| 4458/4458 [01:26<00:00, 51.56it/s]

Training loss after epoch: 0.04323065334335722



100%|██████████| 1395/1395 [00:12<00:00, 114.52it/s]

Epoch 9 completed. Test ROCAUC: 0.9319595637325657
Starting epoch 10



100%|██████████| 4458/4458 [01:24<00:00, 52.54it/s]

Training loss after epoch: 0.038179170969390046



100%|██████████| 1395/1395 [00:12<00:00, 114.04it/s]

Epoch 10 completed. Test ROCAUC: 0.9295884864450991
Starting epoch 11



100%|██████████| 4458/4458 [01:25<00:00, 52.23it/s]

Training loss after epoch: 0.024471788228614674



100%|██████████| 1395/1395 [00:11<00:00, 123.35it/s]

Epoch 11 completed. Test ROCAUC: 0.9232091581489172
Starting epoch 12



100%|██████████| 4458/4458 [01:25<00:00, 52.36it/s]

Training loss after epoch: 0.02008014395518283



100%|██████████| 1395/1395 [00:11<00:00, 123.38it/s]

Epoch 12 completed. Test ROCAUC: 0.9296202789602391
Starting epoch 13



100%|██████████| 4458/4458 [01:34<00:00, 47.24it/s]

Training loss after epoch: 0.02251831791846583



100%|██████████| 1395/1395 [00:11<00:00, 122.66it/s]

Epoch 13 completed. Test ROCAUC: 0.9346136538159839
Starting epoch 14



100%|██████████| 4458/4458 [01:27<00:00, 51.00it/s]

Training loss after epoch: 0.019235942480142765



100%|██████████| 1395/1395 [00:12<00:00, 108.80it/s]

Epoch 14 completed. Test ROCAUC: 0.9320154107401194
Starting epoch 15



100%|██████████| 4458/4458 [01:25<00:00, 52.35it/s]

Training loss after epoch: 0.017530265460588203



100%|██████████| 1395/1395 [00:11<00:00, 124.18it/s]

Epoch 15 completed. Test ROCAUC: 0.9334076598370428





# Adding Example

In [8]:
data_train = pd.read_pickle(f'experiments/data/adding_200_train.pkl')
data_test = pd.read_pickle(f'experiments/data/adding_200_test.pkl')
max_seq_len = max(max(data_train['len']), max(data_test['len']))

In [9]:
data_train.head()

Unnamed: 0,sequence,label,len,bin
13414,"[[-0.08982194422382483, 0.0], [-0.841596420741...",0.241234,476,6
56779,"[[-0.6032521601163512, 0.0], [-0.3534035945595...",0.557188,296,4
11976,"[[0.6875824449807342, 0.0], [0.732433410329577...",0.48989,205,2
29866,"[[-0.11423020701527808, 0.0], [-0.743413113459...",0.169627,178,1
1418,"[[-0.9588989728128303, 0.0], [-0.0150490739209...",0.605129,186,2


In [10]:
torch.cuda.set_device(1)
device = 'cuda:{}'.format(1) if torch.cuda.is_available() else 'cpu'
print('set up device:', device)

net = ChordMixerNet(
    problem='adding',
    vocab_size=1,
    max_seq_len=max_seq_len,
    embedding_size=196,
    track_size=16,
    hidden_size=128,
    mlp_dropout=0.0,
    layer_dropout=0.0,
    n_class=1
)
net = net.to(device)
net.apply(init_weights)

loss = nn.MSELoss()

optimizer = optim.Adam(
    net.parameters(),
    lr=0.0003
)

set up device: cuda:1


In [11]:
trainset = DatasetCreator(
    df=data_train,
    batch_size=20,
    var_len=True
)

trainloader = DataLoader(
    trainset,
    batch_size=20,
    shuffle=False,
    collate_fn=concater_collate,
    drop_last=False,
    num_workers=4
)


# Prepare the testing loader
testset = DatasetCreator(
    df=data_test,
    batch_size=20,
    var_len=True
)

testloader = DataLoader(
    testset,
    batch_size=20,
    shuffle=False,
    collate_fn=concater_collate,
    drop_last=False,
    num_workers=4
)


In [12]:
import numpy as np
def accuracy_adding(predictions, target):
    predictions = np.array(predictions)
    target = np.array(target)
    score = (np.abs(predictions - target) < 0.04).mean()
    return score

for epoch in range(15):
    print(f'Starting epoch {epoch+1}')
    config=None
    train_epoch(config, net, optimizer, loss, trainloader, device=device, log_every=100000, problem='adding')
    accuracy = eval_model(config, net, testloader, metric=accuracy_adding, device=device, problem='adding') 
    print(f'Epoch {epoch+1} completed. Test accuracy: {accuracy}')

Starting epoch 1


100%|██████████| 2493/2493 [03:52<00:00, 10.74it/s]

Training loss after epoch: 0.03848127465388064



100%|██████████| 256/256 [00:14<00:00, 17.32it/s]

Epoch 1 completed. Test accuracy: 0.3056640625
Starting epoch 2



100%|██████████| 2493/2493 [03:29<00:00, 11.92it/s]

Training loss after epoch: 0.008829407858169375



100%|██████████| 256/256 [00:11<00:00, 22.82it/s]


Epoch 2 completed. Test accuracy: 0.3947265625
Starting epoch 3


100%|██████████| 2493/2493 [03:24<00:00, 12.19it/s]

Training loss after epoch: 0.005712602193938677



100%|██████████| 256/256 [00:11<00:00, 23.04it/s]

Epoch 3 completed. Test accuracy: 0.6291015625
Starting epoch 4



100%|██████████| 2493/2493 [03:24<00:00, 12.19it/s]

Training loss after epoch: 0.003492843373081605



100%|██████████| 256/256 [00:14<00:00, 17.29it/s]

Epoch 4 completed. Test accuracy: 0.652734375
Starting epoch 5



100%|██████████| 2493/2493 [03:24<00:00, 12.17it/s]

Training loss after epoch: 0.0024492376346962356



100%|██████████| 256/256 [00:10<00:00, 23.63it/s]

Epoch 5 completed. Test accuracy: 0.6646484375
Starting epoch 6



100%|██████████| 2493/2493 [03:38<00:00, 11.42it/s]

Training loss after epoch: 0.002106257776262887



100%|██████████| 256/256 [00:10<00:00, 23.57it/s]

Epoch 6 completed. Test accuracy: 0.7484375
Starting epoch 7



100%|██████████| 2493/2493 [03:30<00:00, 11.84it/s]

Training loss after epoch: 0.0015214801763509238



100%|██████████| 256/256 [00:10<00:00, 23.80it/s]

Epoch 7 completed. Test accuracy: 0.80546875
Starting epoch 8



100%|██████████| 2493/2493 [03:27<00:00, 12.02it/s]

Training loss after epoch: 0.0012427910689068503



100%|██████████| 256/256 [00:10<00:00, 23.38it/s]

Epoch 8 completed. Test accuracy: 0.8611328125
Starting epoch 9



100%|██████████| 2493/2493 [03:31<00:00, 11.76it/s]

Training loss after epoch: 0.0009323694252113706



100%|██████████| 256/256 [00:11<00:00, 22.86it/s]

Epoch 9 completed. Test accuracy: 0.9064453125
Starting epoch 10



100%|██████████| 2493/2493 [03:24<00:00, 12.19it/s]

Training loss after epoch: 0.0007226888412659359



100%|██████████| 256/256 [00:10<00:00, 23.59it/s]

Epoch 10 completed. Test accuracy: 0.9193359375
Starting epoch 11



100%|██████████| 2493/2493 [03:27<00:00, 12.00it/s]

Training loss after epoch: 0.0005740218637673639



100%|██████████| 256/256 [00:10<00:00, 23.76it/s]

Epoch 11 completed. Test accuracy: 0.9494140625
Starting epoch 12



100%|██████████| 2493/2493 [03:25<00:00, 12.13it/s]

Training loss after epoch: 0.0006029668948928487



100%|██████████| 256/256 [00:14<00:00, 17.77it/s]

Epoch 12 completed. Test accuracy: 0.9611328125
Starting epoch 13



100%|██████████| 2493/2493 [03:28<00:00, 11.94it/s]

Training loss after epoch: 0.0004457904522805029



100%|██████████| 256/256 [00:14<00:00, 17.38it/s]

Epoch 13 completed. Test accuracy: 0.9724609375
Starting epoch 14



100%|██████████| 2493/2493 [04:00<00:00, 10.37it/s]

Training loss after epoch: 0.00040425269517010574



100%|██████████| 256/256 [00:13<00:00, 19.28it/s]

Epoch 14 completed. Test accuracy: 0.97265625
Starting epoch 15



100%|██████████| 2493/2493 [03:31<00:00, 11.80it/s]

Training loss after epoch: 0.00036565818400111484



100%|██████████| 256/256 [00:10<00:00, 23.77it/s]


Epoch 15 completed. Test accuracy: 0.9669921875
