In [1]:
import os
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from protein_dataset import ProteinDataset
import constants as CONSTANTS
from models.basic_vae_1 import BasicVAE1
from models.vae_loss import VAELoss


%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Reproducibility: seed every RNG *before* the model is built so that weight
# initialisation (and any numpy-based randomness) is repeatable across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BasicVAE1()
model.to(device)
criterion = VAELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
batch_size = 30
n_epochs = 10
print_every = 2   # epochs between train-loss prints in the training loop
test_every = 2    # epochs between validation runs / best-model checkpointing
plot_every = 2    # epochs between loss-curve redraws

In [3]:
train_dataset = ProteinDataset(CONSTANTS.TRAIN_FILE)
# Shuffle training data every epoch: an unshuffled loader feeds the optimizer
# the exact same batch sequence each epoch, which hurts SGD-style training.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(len(train_dataset))
x, y = train_dataset[0]
print(x.shape, y.shape)
len(train_loader)

1n1f: 1-hot size: torch.Size([153, 20]) contact-map size: torch.Size([153, 153])
1msc: 1-hot size: torch.Size([129, 20]) contact-map size: torch.Size([129, 129])
1f2t: 1-hot size: torch.Size([288, 20]) contact-map size: torch.Size([288, 288])
1qr0: 1-hot size: torch.Size([228, 20]) contact-map size: torch.Size([228, 228])
1uv4: 1-hot size: torch.Size([291, 20]) contact-map size: torch.Size([291, 291])
1j30: 1-hot size: torch.Size([278, 20]) contact-map size: torch.Size([278, 278])
1o4v: 1-hot size: torch.Size([169, 20]) contact-map size: torch.Size([169, 169])
1nko: 1-hot size: torch.Size([122, 20]) contact-map size: torch.Size([122, 122])
1nfp: 1-hot size: torch.Size([228, 20]) contact-map size: torch.Size([228, 228])
1ctf: 1-hot size: torch.Size([68, 20]) contact-map size: torch.Size([68, 68])
1dk8: 1-hot size: torch.Size([147, 20]) contact-map size: torch.Size([147, 147])
1pch: 1-hot size: torch.Size([88, 20]) contact-map size: torch.Size([88, 88])
1reg: 1-hot size: torch.Size([242,

1ihj: 1-hot size: torch.Size([199, 20]) contact-map size: torch.Size([199, 199])
1fm0: 1-hot size: torch.Size([223, 20]) contact-map size: torch.Size([223, 223])
1doi: 1-hot size: torch.Size([128, 20]) contact-map size: torch.Size([128, 128])
1xer: 1-hot size: torch.Size([102, 20]) contact-map size: torch.Size([102, 102])
1s67: 1-hot size: torch.Size([234, 20]) contact-map size: torch.Size([234, 234])
1lmb: 1-hot size: torch.Size([179, 20]) contact-map size: torch.Size([179, 179])
1owf: 1-hot size: torch.Size([190, 20]) contact-map size: torch.Size([190, 190])
1gui: 1-hot size: torch.Size([155, 20]) contact-map size: torch.Size([155, 155])
1aly: 1-hot size: torch.Size([146, 20]) contact-map size: torch.Size([146, 146])
1ucr: 1-hot size: torch.Size([149, 20]) contact-map size: torch.Size([149, 149])
1poc: 1-hot size: torch.Size([134, 20]) contact-map size: torch.Size([134, 134])
1pgx: 1-hot size: torch.Size([70, 20]) contact-map size: torch.Size([70, 70])
1tua: 1-hot size: torch.Size([1

1nu0: 1-hot size: torch.Size([244, 20]) contact-map size: torch.Size([244, 244])
1q3f: 1-hot size: torch.Size([223, 20]) contact-map size: torch.Size([223, 223])
1di2: 1-hot size: torch.Size([129, 20]) contact-map size: torch.Size([129, 129])
1n7s: 1-hot size: torch.Size([276, 20]) contact-map size: torch.Size([276, 276])
1iap: 1-hot size: torch.Size([190, 20]) contact-map size: torch.Size([190, 190])
1jo0: 1-hot size: torch.Size([193, 20]) contact-map size: torch.Size([193, 193])
1dfu: 1-hot size: torch.Size([94, 20]) contact-map size: torch.Size([94, 94])
1hh8: 1-hot size: torch.Size([192, 20]) contact-map size: torch.Size([192, 192])
1eh6: 1-hot size: torch.Size([168, 20]) contact-map size: torch.Size([168, 168])
1lo7: 1-hot size: torch.Size([140, 20]) contact-map size: torch.Size([140, 140])
1gmx: 1-hot size: torch.Size([107, 20]) contact-map size: torch.Size([107, 107])
1o54: 1-hot size: torch.Size([265, 20]) contact-map size: torch.Size([265, 265])
1cdw: 1-hot size: torch.Size([1

1m1h: 1-hot size: torch.Size([182, 20]) contact-map size: torch.Size([182, 182])
1nn5: 1-hot size: torch.Size([204, 20]) contact-map size: torch.Size([204, 204])
1vr9: 1-hot size: torch.Size([234, 20]) contact-map size: torch.Size([234, 234])
1x6i: 1-hot size: torch.Size([176, 20]) contact-map size: torch.Size([176, 176])
1whz: 1-hot size: torch.Size([67, 20]) contact-map size: torch.Size([67, 67])
1a34: 1-hot size: torch.Size([147, 20]) contact-map size: torch.Size([147, 147])
1x8q: 1-hot size: torch.Size([184, 20]) contact-map size: torch.Size([184, 184])
1alu: 1-hot size: torch.Size([157, 20]) contact-map size: torch.Size([157, 157])
1wtr: 1-hot size: torch.Size([66, 20]) contact-map size: torch.Size([66, 66])
1ng2: 1-hot size: torch.Size([176, 20]) contact-map size: torch.Size([176, 176])
1q8c: 1-hot size: torch.Size([132, 20]) contact-map size: torch.Size([132, 132])
1d9c: 1-hot size: torch.Size([240, 20]) contact-map size: torch.Size([240, 240])
1rlh: 1-hot size: torch.Size([151,

1vr8: 1-hot size: torch.Size([132, 20]) contact-map size: torch.Size([132, 132])
1rfs: 1-hot size: torch.Size([127, 20]) contact-map size: torch.Size([127, 127])
1ls1: 1-hot size: torch.Size([289, 20]) contact-map size: torch.Size([289, 289])
1i6j: 1-hot size: torch.Size([256, 20]) contact-map size: torch.Size([256, 256])
1ly1: 1-hot size: torch.Size([152, 20]) contact-map size: torch.Size([152, 152])
1f7d: 1-hot size: torch.Size([235, 20]) contact-map size: torch.Size([235, 235])
1wv3: 1-hot size: torch.Size([181, 20]) contact-map size: torch.Size([181, 181])
1a1k: 1-hot size: torch.Size([85, 20]) contact-map size: torch.Size([85, 85])
1pji: 1-hot size: torch.Size([267, 20]) contact-map size: torch.Size([267, 267])
1ujc: 1-hot size: torch.Size([156, 20]) contact-map size: torch.Size([156, 156])
1jyh: 1-hot size: torch.Size([155, 20]) contact-map size: torch.Size([155, 155])
1whi: 1-hot size: torch.Size([122, 20]) contact-map size: torch.Size([122, 122])
1ha1: 1-hot size: torch.Size([1

1jfx: 1-hot size: torch.Size([217, 20]) contact-map size: torch.Size([217, 217])
1v30: 1-hot size: torch.Size([118, 20]) contact-map size: torch.Size([118, 118])
1wka: 1-hot size: torch.Size([143, 20]) contact-map size: torch.Size([143, 143])
1fjl: 1-hot size: torch.Size([182, 20]) contact-map size: torch.Size([182, 182])
1jg1: 1-hot size: torch.Size([215, 20]) contact-map size: torch.Size([215, 215])
1bte: 1-hot size: torch.Size([186, 20]) contact-map size: torch.Size([186, 186])
1wvh: 1-hot size: torch.Size([132, 20]) contact-map size: torch.Size([132, 132])
1wlu: 1-hot size: torch.Size([117, 20]) contact-map size: torch.Size([117, 117])
1qau: 1-hot size: torch.Size([112, 20]) contact-map size: torch.Size([112, 112])
1gk7: 1-hot size: torch.Size([39, 20]) contact-map size: torch.Size([39, 39])
1wy3: 1-hot size: torch.Size([34, 20]) contact-map size: torch.Size([34, 34])
1jmw: 1-hot size: torch.Size([146, 20]) contact-map size: torch.Size([146, 146])
1gl2: 1-hot size: torch.Size([229,

5028

In [4]:
val_dataset = ProteinDataset(CONSTANTS.VAL_FILE)
# No shuffling for evaluation: it brings no benefit, and with a partial last
# batch it changes the mean-of-batch-losses from run to run, adding noise to
# best-model selection.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
print(len(val_loader))

1sfu: 1-hot size: torch.Size([140, 20]) contact-map size: torch.Size([140, 140])
1a27: 1-hot size: torch.Size([285, 20]) contact-map size: torch.Size([285, 285])
1gk6: 1-hot size: torch.Size([107, 20]) contact-map size: torch.Size([107, 107])
1pdo: 1-hot size: torch.Size([129, 20]) contact-map size: torch.Size([129, 129])
1mgt: 1-hot size: torch.Size([169, 20]) contact-map size: torch.Size([169, 169])
1mhn: 1-hot size: torch.Size([59, 20]) contact-map size: torch.Size([59, 59])
1py9: 1-hot size: torch.Size([116, 20]) contact-map size: torch.Size([116, 116])
1o50: 1-hot size: torch.Size([141, 20]) contact-map size: torch.Size([141, 141])
1j0p: 1-hot size: torch.Size([108, 20]) contact-map size: torch.Size([108, 108])
1ng6: 1-hot size: torch.Size([148, 20]) contact-map size: torch.Size([148, 148])
1efd: 1-hot size: torch.Size([262, 20]) contact-map size: torch.Size([262, 262])
1tkj: 1-hot size: torch.Size([277, 20]) contact-map size: torch.Size([277, 277])
1r75: 1-hot size: torch.Size([1

1eaq: 1-hot size: torch.Size([243, 20]) contact-map size: torch.Size([243, 243])
1euv: 1-hot size: torch.Size([300, 20]) contact-map size: torch.Size([300, 300])
1dyp: 1-hot size: torch.Size([266, 20]) contact-map size: torch.Size([266, 266])
1r0u: 1-hot size: torch.Size([142, 20]) contact-map size: torch.Size([142, 142])
1rl6: 1-hot size: torch.Size([164, 20]) contact-map size: torch.Size([164, 164])
1q33: 1-hot size: torch.Size([287, 20]) contact-map size: torch.Size([287, 287])
1kve: 1-hot size: torch.Size([280, 20]) contact-map size: torch.Size([280, 280])
1hq0: 1-hot size: torch.Size([295, 20]) contact-map size: torch.Size([295, 295])
1u84: 1-hot size: torch.Size([81, 20]) contact-map size: torch.Size([81, 81])
1tvg: 1-hot size: torch.Size([136, 20]) contact-map size: torch.Size([136, 136])
1vaj: 1-hot size: torch.Size([201, 20]) contact-map size: torch.Size([201, 201])
1prz: 1-hot size: torch.Size([242, 20]) contact-map size: torch.Size([242, 242])
1rfe: 1-hot size: torch.Size([1

In [None]:
test_dataset = ProteinDataset(CONSTANTS.TEST_FILE)
# Evaluation loader: keep a fixed order so the reported test loss is reproducible.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print(len(test_loader))

In [5]:
def train():
    """Run one training epoch and return the mean per-batch loss.

    Uses the notebook-level ``model``, ``train_loader``, ``optimizer``,
    ``criterion`` and ``device``. Puts ``model`` in training mode, then for
    every ``(x, y)`` batch does forward pass -> ``criterion(y, y_prime, mu,
    logvar)`` -> backward pass -> optimizer step.

    Returns:
        float: mean of the per-batch losses returned by ``criterion``.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches.
    """
    model.train()
    batch_losses = []
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_prime, mu, logvar = model(x)
        loss = criterion(y, y_prime, mu, logvar)
        loss.backward()
        optimizer.step()
        # Detach before storing so the kept values (and the final stack/mean)
        # carry no autograd history across the whole epoch.
        batch_losses.append(loss.detach())
    if not batch_losses:
        raise RuntimeError("train_loader produced no batches")
    return torch.stack(batch_losses).mean().item()
        

In [9]:
# NOTE(review): this runs one full training epoch (including optimizer steps)
# outside the main training loop, so the loop starts from an already-updated
# model. It was also executed out of order (In [9]). For a clean
# Restart & Run All, remove this cell or rebuild `model`/`optimizer` before
# the training loop.
train()

8814.5595703125

In [6]:
def test(data_loader):
    model.eval()
    loss = 0.0
    losses = []
    for i, (x, y) in enumerate(data_loader):
        x, y = x.to(device), y.to(device)
        y_prime, mu, logvar = model(x)
        loss = criterion(y, y_prime, mu, logvar)
        losses.append(loss)
    return torch.stack(losses).mean().item()

In [7]:
# Main training loop: train every epoch, validate + checkpoint every
# `test_every` epochs, and redraw the loss curves every `plot_every` epochs.
# NOTE(review): the saved run of this cell died with a CUDA device-side assert.
# Re-running on CPU (or with CUDA_LAUNCH_BLOCKING=1) should surface the real
# error. It is possibly a prediction/target outside the range VAELoss expects
# (unverified).
train_losses = []
val_losses = []
val_epochs = []  # epoch numbers at which val_losses entries were recorded
best_test_loss = np.inf  # best *validation* loss so far (name kept for compatibility)
best_model_path = '../outputs/best_model.pth'
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

for epoch in range(1, n_epochs + 1):
    train_loss = train()
    train_losses.append(train_loss)

    if epoch % print_every == 0:
        print("epoch:{}/{}, train_loss: {:.5f}".format(epoch, n_epochs, train_loss))

    if epoch % test_every == 0:
        val_loss = test(val_loader)
        print("epoch:{}/{}, val_loss: {:.5f}".format(epoch, n_epochs, val_loss))
        val_losses.append(val_loss)
        val_epochs.append(epoch)
        if val_loss < best_test_loss:
            best_test_loss = val_loss
            print('Updating best val loss: {:.5f}'.format(best_test_loss))
            torch.save(model.state_dict(), best_model_path)

    if epoch % plot_every == 0:
        # Plot against epoch numbers: val loss is only recorded every
        # `test_every` epochs, so list indices would misalign the two curves.
        fig, ax = plt.subplots()
        ax.plot(range(1, len(train_losses) + 1), train_losses, label='train')
        ax.plot(val_epochs, val_losses, marker='o', label='val')
        ax.set(xlabel='epoch', ylabel='loss', title='BasicVAE1 training curves')
        ax.legend()
        plt.show()



RuntimeError: reduce failed to synchronize: cudaErrorAssert: device-side assert triggered

In [None]:
# Report test loss for the best checkpoint (selected on validation loss) rather
# than whatever weights the last epoch left in memory. Falls back to the
# in-memory model when no checkpoint has been written.
# NOTE(review): a checkpoint left over from an earlier run would also be picked
# up here; delete ../outputs/best_model.pth when changing the architecture.
best_model_path = '../outputs/best_model.pth'
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=device))
test_loss = test(test_loader)
test_loss