In [1]:
import os
import glob
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from ast import literal_eval
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.autograd import Variable
import gc

In [2]:
torch.cuda.set_device(1) 

In [3]:
class Dataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.labels = pd.read_csv(csv_file).iloc[:]
        self.root_dir = root_dir
        self.transform = transform
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        feature_path = os.path.join(self.root_dir, self.labels.iloc[idx,0])
        embeddings = torch.load(feature_path)
        features = embeddings[:, :, :768]
        maps = embeddings[:, :, 768:]
        rmsds = [float(label) for label in literal_eval(self.labels.iloc[idx, 1])]
        if self.transform:
            embedding = self.transform(embedding)
        return features, maps, rmsds, feature_path
    
dataset = Dataset(csv_file='/home/vera/projects/masters_project/data/rmsd_dataset.csv',
                                    root_dir='/home/vera/projects/masters_project/data/s-pred_features/')

## Load in the dataset

In [6]:
# Split dataset into train, validation and test sets
train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size

train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size, test_size])

# Test if the dataset is split correctly
print(len(train_dataset), len(valid_dataset), len(test_dataset))

# Create the dataloaders
batch_size = 1 
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Test if the dataloaders are working correctly
for i, (features, maps, rmsds, feature_path) in enumerate(train_loader):
    print(features.shape)
    print(maps.shape)
    print(len(rmsds))
    print(feature_path)
    break


8304 1038 1038
torch.Size([1, 1, 164, 768])
torch.Size([1, 1, 164, 288])
164
('/home/vera/projects/masters_project/data/s-pred_features/embeddings_5cuh_a.pt',)


In [7]:
batch_size = 4
# Sort the data by length of the embeddings
train_dataset = sorted(train_dataset, key=lambda x: x[0].shape[1], reverse=True)

# Go through the data and create batches of embeddings with the same length
train_batches = []

for i in range(0, len(train_dataset), batch_size):
    batch = train_dataset[i:i+batch_size]
    # Check if all embeddings in the batch have the same length
    if len(set([embedding[0].shape[1] for embedding in batch])) != 1:
        continue
    train_batches.append(batch)
# Shuffle the batches
random.shuffle(train_batches)

# Test if the batches are created correctly
for i, batch in enumerate(train_batches):
    print('Batch ' + str(i))
    for features, maps, rmsds, feature_path in batch:
        print(features.shape)
        print(maps.shape)
        print(len(rmsds))
        print(feature_path)
    break

# Create data loaders for the batches
batch_size = 4
train_loader = DataLoader(train_batches)

# Test if the dataloaders are working correctly
for i, batch in enumerate(train_loader):
    print('Batch ' + str(i))
    for features, maps, rmsds, feature_path in batch:
        print(features.shape)
        print(maps.shape)
        print(len(rmsds))
        print(feature_path)
    break

print(len(train_batches))

KeyboardInterrupt: 

## Create the model

In [8]:
# Create the LSTM model
class lstm_net(nn.Module):

    def __init__(self, input_feature_size=768, hidden_node=256, dropout=0.25, class_num=8):
        super(lstm_net, self).__init__()

        self.linear_proj = nn.Sequential(
            nn.Linear(input_feature_size, input_feature_size // 2),
            nn.InstanceNorm1d(input_feature_size // 2),
            nn.ReLU(),
            nn.Linear(input_feature_size // 2, input_feature_size // 4),
            nn.InstanceNorm1d(input_feature_size // 4),
            nn.ReLU(),
            nn.Linear(input_feature_size // 4, input_feature_size // 4),
        )

        lstm_input_feature_size = input_feature_size // 4 + 144*2

        self.lstm = nn.LSTM(
            input_size=lstm_input_feature_size,
            hidden_size=hidden_node,
            num_layers=2,
            bidirectional=True,
            dropout=dropout,
            batch_first=True,
        )

        self.to_property = nn.Sequential(
            nn.Linear(hidden_node * 2, hidden_node * 2),
            nn.InstanceNorm1d(hidden_node * 2),
            nn.ReLU(),
            nn.Linear(hidden_node * 2, class_num),
        )

    def forward(self, msa_query_embeddings, msa_attention_features):
        msa_query_embeddings = self.linear_proj(msa_query_embeddings)

        lstm_input = torch.cat([msa_query_embeddings, msa_attention_features], dim=2)
        lstm_input = lstm_input.permute((1, 0, 2))

        lstm_output, lstm_hidden = self.lstm(lstm_input)
        lstm_output = lstm_output.permute((1, 0, 2))
        
        label_output = self.to_property(lstm_output)

        return label_output

## Train the model

In [140]:
EPOCHS = 10
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0
BATCH_SIZE = 1
HIDDEN_NODE = 256
DROPOUT = 0
CLASS_NUM = 1
NUM_ACCUMULATION_STEPS = 2

RMSD_THRESHOLD = 1

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)


model = lstm_net(input_feature_size=768, hidden_node=HIDDEN_NODE, dropout=DROPOUT, class_num=CLASS_NUM)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

train_loss_list = []
valid_loss_list = []

model.train()

for epoch in range(1, EPOCHS+1):
    model.zero_grad()
    train_loss = 0
    valid_loss = 0

    for i, (features, maps, rmsds, feature_path) in enumerate(train_loader):
        features = features[0,:,:,:].to(device)
        maps = maps[0,:,:,:].to(device)
        labels = torch.Tensor([int(label >= RMSD_THRESHOLD) for label in rmsds]).to(device)
        optimizer.zero_grad()

        output = model(msa_query_embeddings=features, msa_attention_features=maps)
        loss = criterion(output[0,:,0], labels)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    train_loss_list.append(train_loss / len(train_loader))

    model.eval()
    with torch.no_grad():
        for i, (features, maps, rmsds, feature_path) in enumerate(valid_loader):
            features = features[0,:,:,:].to(device)
            maps = maps[0,:,:,:].to(device)
            labels = torch.Tensor([int(label >= RMSD_THRESHOLD) for label in rmsds]).to(device)

            output = model(msa_query_embeddings=features, msa_attention_features=maps)
            loss = criterion(output[0,:,0], labels)

            valid_loss += loss.item()

        valid_loss_list.append(valid_loss / len(valid_loader))

    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, train_loss_list[-1], valid_loss_list[-1]))
    
    model.train()

TypeError: list indices must be integers or slices, not tuple