In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from utils import *
plt.style.use('ggplot')

In [31]:
import torch
from torch.utils.data import Dataset
import numpy as np

class PolymerDataset(Dataset):
    def __init__(self, data_paths, timesteps=100) -> None:
        self.raw_data = [np.load(data_path, allow_pickle=True) for data_path in data_paths]
        self.prepare(timesteps=timesteps)
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
    
    def _process_event(self, event, timesteps=100, diff_threshold=0):
        compressed_event = []
        step_size = int(np.ceil(len(event) / timesteps))
        for i in range(timesteps):
            sub_event = event[i*step_size:(i+1)*step_size]
            features = build_features(sub_event, diff_threshold=diff_threshold)
            compressed_event.append(np.array(list(features.values())))
        return np.array(compressed_event)

    def prepare(self, timesteps=100, diff_threshold=0):
        data = []
        labels = []

        for data_index, raw_data in enumerate(self.raw_data):
            for event in raw_data:
                processed_event = self._process_event(event, timesteps=timesteps, diff_threshold=diff_threshold)
                data.append(processed_event)
                labels.append(data_index)

        self.data = torch.tensor(np.array(data), dtype=torch.float)
        self.labels = torch.tensor(np.array(labels), dtype=torch.long)
        return self

In [32]:
dataset = PolymerDataset(['data/AA66266AA.npy', 'data/AA66466AA.npy', 'data/AA66566AA.npy'], timesteps=64)

In [33]:
dataset.data.shape, dataset.labels.shape

(torch.Size([140153, 64, 8]), torch.Size([140153]))

In [36]:
from torch.utils.data import random_split
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_data, test_data = random_split(dataset, [train_size, test_size])

In [51]:
(dataset.labels == 0).sum(), (dataset.labels == 1).sum(), (dataset.labels == 2).sum()

(tensor(22039), tensor(75040), tensor(43074))

In [63]:
from torch.utils.data import DataLoader

class PolymerLSTM(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_size=32) -> None:
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=num_features, hidden_size=hidden_size, batch_first=True)
        self.linear = torch.nn.Linear(hidden_size, num_classes)
    
    def forward(self, X):
        lstm_out, _ = self.lstm(X)
        outputs = lstm_out[:, -1, :]
        outputs = self.linear(outputs)
        probs = torch.nn.functional.log_softmax(outputs, dim=1)
        return probs
    
    def predict(self, X):
        probs = self.forward(X)
        preds = torch.argmax(probs, dim=1, keepdim=False)
        return preds


def train(dataset, num_epochs=100, batch_size=64, num_features=2, num_classes=2):
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = PolymerLSTM(num_features, num_classes)
    loss_function = torch.nn.NLLLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(num_epochs):
        num_correct = 0
        for X, y in iter(data_loader):
            model.zero_grad()
            probs = model(X)
            loss = loss_function(probs, y)
            loss.backward()
            optimizer.step()
            preds = torch.argmax(probs, dim=1, keepdim=False)
            num_correct += (preds == y).sum()
        print(f'epoch={epoch}/{num_epochs}, loss={loss}, accuracy={num_correct*100/len(dataset)}')
    
    return model

In [65]:
model = train(train_data, num_features=dataset.data.shape[2], num_classes=3, batch_size=128)

epoch=0/100, loss=0.8802793622016907, accuracy=55.40215301513672
epoch=1/100, loss=0.9036217331886292, accuracy=57.134193420410156
epoch=2/100, loss=0.8428705334663391, accuracy=57.650596618652344
epoch=3/100, loss=0.8184763193130493, accuracy=58.44169616699219
epoch=4/100, loss=1.013095498085022, accuracy=59.25331497192383
epoch=5/100, loss=0.7359552979469299, accuracy=60.13449478149414
epoch=6/100, loss=0.8546602129936218, accuracy=61.260948181152344
epoch=7/100, loss=0.811974287033081, accuracy=61.978023529052734
epoch=8/100, loss=0.7188849449157715, accuracy=62.30088806152344
epoch=9/100, loss=0.7336299419403076, accuracy=62.979610443115234
epoch=10/100, loss=0.7796737551689148, accuracy=63.37025833129883
epoch=11/100, loss=0.7835960388183594, accuracy=63.62444305419922
epoch=12/100, loss=0.7308964729309082, accuracy=63.96603775024414
epoch=13/100, loss=0.7248246669769287, accuracy=64.26214599609375
epoch=14/100, loss=0.7424956560134888, accuracy=64.52168273925781
epoch=15/100, los

KeyboardInterrupt: 