In [1]:
!pip install fair-esm

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0
[0m

In [2]:
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch import normal
from torch.nn.modules.linear import Linear
import os
import torch
import copy
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from Bio import SeqIO
import esm
from numpy.lib.function_base import average
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import torch.optim as optim

# Class name -> integer label. 'nonarg' (14) is the label `read` assigns to
# records whose id does not contain 'FEATURE'.
antibiotic_fam = {
    'aminoglycoside': 0,
    'macrolide-lincosamide-streptogramin': 1,
    'polymyxin': 2,
    'fosfomycin': 3,
    'multidrug': 7,
    'bacitracin': 5,
    'quinolone': 6,
    'trimethoprim': 4,
    'chloramphenicol': 8,
    'tetracycline': 9,
    'rifampin': 10,
    'beta_lactam': 11,
    'sulfonamide': 12,
    'glycopeptide': 13,
    'nonarg': 14,
}

def read(file_path):
    """Parse a labelled FASTA file into a list of ``(label, sequence)`` pairs.

    A record whose id contains 'FEATURE' takes its class name from the fourth
    '|'-separated field of its description; every other record is labelled
    'nonarg'. Class names are mapped to integers through ``antibiotic_fam``.

    Args:
        file_path: path of the FASTA file to read.

    Returns:
        list of ``(int_label, record.seq)`` tuples in file order.

    Raises:
        KeyError: if a class name is not a key of ``antibiotic_fam``.
        IndexError: if a 'FEATURE' description has fewer than four fields.
    """
    records = []
    for record in SeqIO.parse(file_path, "fasta"):
        if 'FEATURE' in record.id:
            family = record.description.split('|')[3]
        else:
            family = 'nonarg'
        records.append((antibiotic_fam[family], record.seq))
    return records

def read_test(file_path):
    """Parse an unlabelled FASTA file.

    Args:
        file_path: path of the FASTA file to read.

    Returns:
        A pair ``(samples, ids)``: ``samples`` is a list of ``(-1, record.seq)``
        tuples (``-1`` is a placeholder label) and ``ids`` is the list of record
        ids, both in file order.
    """
    samples, ids = [], []
    for record in SeqIO.parse(file_path, "fasta"):
        samples.append((-1, record.seq))
        ids.append(record.id)
    return samples, ids

data_dir = "/kaggle/input/aist4010-spring2023-a2/data"

# Labelled splits are lists of (label, sequence); the test split also yields
# the record ids used for the submission file.
train_data = read(os.path.join(data_dir, "train.fasta"))
val_data = read(os.path.join(data_dir, "val.fasta"))
test_data, test_list = read_test(os.path.join(data_dir, "test.fasta"))

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class Net(nn.Module):
    """Fully connected classifier over 1280-dimensional input vectors.

    ``forward`` only uses ``self.classifier``. ``self.CNN`` is constructed but
    never called; it is kept so the module's parameters and ``state_dict`` keys
    stay the same as before.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes = 15):
        super(Net, self).__init__()
        # Not used by forward(); see class docstring.
        self.CNN = nn.Sequential(
            nn.Conv1d(in_channels=128, out_channels=160, kernel_size=9),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=160, out_channels=240, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=240, out_channels=240, kernel_size=5),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1280, 1024),
            nn.Sigmoid(),
            nn.Linear(1024, 1024),
            nn.Sigmoid(),
            nn.Linear(1024, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Set every Linear layer to N(0, 0.02) weights and zero bias."""
        # self.modules() walks all nested submodules. Iterating self._modules
        # yields only the *names* of the direct children, so the isinstance
        # check could never match and the init was silently skipped.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
                m.bias.data.zero_()

    def forward(self, input):
        """Return class logits of shape ``(N, num_classes)``.

        ``input`` is cast to float and viewed as ``(-1, 1280)``, so a single
        1280-vector is treated as a batch of one.
        """
        x = input.float().view(-1, 1280)
        return self.classifier(x)

In [4]:
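# Load the pretrained ESM model and its batch converter (left disabled; a later cell loads saved embeddings from .pt files instead)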
# model, alphabet = esm.pretrained.esm1_t34_670M_UR50D()
# model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# model = model.cuda()
# converter = alphabet.get_batch_converter()

In [5]:
# def embedding(data, converter):
#     representation = []
#     labels, strs, tokens = converter(data)
#     with torch.no_grad():
#         for i in tqdm(range(0,len(tokens),20)):
#             results = model(tokens[i:i+20].cuda(), repr_layers=[12], return_contacts=False)
#             for j in results["representations"][12]:
#                 representation.append(j.mean(0))
# #   Stack the tensor
#     Xs = torch.stack(representation, dim=0).cpu()
#     return Xs, np.array(labels)

# train_X, train_y = embedding(train_data, converter)
# train_data_esm = (train_X, train_y)
# torch.save(train_data_esm, 'train.pt')
# val_X, val_y  = embedding(val_data, converter)
# val_data_esm = (val_X, val_y)
# torch.save(val_data_esm, 'val.pt')

# test_X, _ = embedding(test_data, converter)
# torch.save(test_X, 'test.pt')

In [6]:
class Data(Dataset):
    """Dataset over a ``(features, labels)`` pair.

    ``esm_data[0]`` holds the per-sample inputs and ``esm_data[1]`` the matching
    labels; both are indexed with the same integer position.
    """

    def __init__(self, esm_data):
        features, labels = esm_data[0], esm_data[1]
        self.Xs = features
        self.ys = labels

    def __len__(self):
        # The dataset length is defined by the label container.
        return len(self.ys)

    def __getitem__(self, index):
        return self.Xs[index], self.ys[index]

# Saved (features, labels) pairs for the train and validation splits.
# NOTE(review): torch.load unpickles the file; only load files you trust.
train_esm = torch.load('/kaggle/input/asm2-check-pt/aist4010 asm2/train.pt')
val_esm = torch.load('/kaggle/input/asm2-check-pt/aist4010 asm2/val.pt')
# train_esm = torch.load('/kaggle/working/train.pt')
# val_esm = torch.load('/kaggle/working/val.pt')

trainset, valset = Data(train_esm), Data(val_esm)

# Only the training loader shuffles; validation order stays fixed.
trainloader = DataLoader(dataset=trainset, batch_size=128, shuffle=True)
valloader = DataLoader(dataset=valset, batch_size=128, shuffle=False)

In [7]:
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

    Batches are moved to the module-level ``device``.

    Returns:
        ``(loss_sum, macro_f1)``: the sum of the per-batch loss values and the
        macro-averaged F1 over every sample seen in the pass.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0
    y_true, y_pred = [], []
    with torch.set_grad_enabled(training):
        for batch in loader:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            y_pred.append(outputs.detach().argmax(1).cpu().numpy())
            y_true.append(labels.detach().cpu().numpy())

    y_true = np.concatenate(y_true, axis=0)
    y_pred = np.concatenate(y_pred, axis=0)
    return loss_sum, f1_score(y_true=y_true, y_pred=y_pred, average='macro')


def train_model(model, criterion, optimizer,scheduler, num_epochs=10):
    """Train ``model`` and keep the weights with the best validation macro-F1.

    Reads the module-level ``trainloader``, ``valloader`` and ``device``.
    Each epoch runs one training pass and one validation pass, then steps
    ``scheduler`` once. Statistics are printed every fifth epoch; the printed
    losses are sums over batches, not means.

    After the last epoch the best weights are loaded back into ``model`` and
    written to ``./train_esm.pth``, so the checkpoint on disk matches the model
    this function hands back.

    Args:
        model: network to train; updated in place.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of epochs to run.

    Returns:
        ``model``, carrying the best-scoring weights.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_f1 = 0.0

    for epoch in range(num_epochs):
        # Always go through the `model` argument, never a global network.
        train_loss, train_f1 = _run_epoch(model, trainloader, criterion, optimizer)
        val_loss, val_f1 = _run_epoch(model, valloader, criterion)
        scheduler.step()

        # Save the best param
        if val_f1 > best_f1:
            best_f1 = val_f1
            best_model_wts = copy.deepcopy(model.state_dict())

        # Print stat every 5 epoch
        if (epoch+1)%5 == 0:
            print(f'epoch: {epoch + 1}')
            print(f'train_loss: {train_loss:.3f} val_loss: {val_loss:.3f} train_f1: {train_f1:.2f} val_fl: {val_f1:.3f}')

    # Restore the best weights first so the saved checkpoint is the best model,
    # not whatever the final epoch produced.
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), './train_esm.pth')

    print(f'Best f1 score is {best_f1:.3f}')
    return model


In [8]:
from torch.optim import lr_scheduler
if __name__ == '__main__':
    net = Net(num_classes=15)
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    # optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    # Named `scheduler` so the imported `lr_scheduler` module is not rebound
    # to a StepLR instance.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.2)
    # train_model loads the best weights into `net` in place, so `net` is used
    # for inference below.
    train_model(net, criterion, optimizer, scheduler, num_epochs=200)

    # Load test data
#     test_esm = torch.load('/kaggle/working/test.pt')
    test_esm = torch.load('/kaggle/input/asm2-check-pt/aist4010 asm2/test.pt')
    label = []

    net.eval()
    with torch.no_grad():
        for data in test_esm:
            # Follow `device` rather than hard-coding CUDA, so the cell also
            # runs on a CPU-only kernel.
            outputs = net(data.to(device))
            # Net.forward views each sample as a batch of one, so argmax(1)
            # holds exactly one prediction.
            label.append(outputs.argmax(1).item())

    results = {'id': test_list, 'label': label}
    results_df = pd.DataFrame(results)
    results_df.to_csv('submission.csv', index=False)

epoch: 5
train_loss: 8.852 val_loss: 2.832 train_f1: 0.96 val_fl: 0.946
epoch: 10
train_loss: 4.829 val_loss: 2.792 train_f1: 0.97 val_fl: 0.956
epoch: 15
train_loss: 2.790 val_loss: 4.545 train_f1: 0.98 val_fl: 0.940
epoch: 20
train_loss: 3.292 val_loss: 3.506 train_f1: 0.98 val_fl: 0.943
epoch: 25
train_loss: 3.269 val_loss: 3.451 train_f1: 0.98 val_fl: 0.960
epoch: 30
train_loss: 2.172 val_loss: 3.574 train_f1: 0.98 val_fl: 0.963
epoch: 35
train_loss: 1.894 val_loss: 4.033 train_f1: 0.99 val_fl: 0.960
epoch: 40
train_loss: 2.010 val_loss: 3.751 train_f1: 0.98 val_fl: 0.966
epoch: 45
train_loss: 2.362 val_loss: 3.872 train_f1: 0.98 val_fl: 0.960
epoch: 50
train_loss: 1.975 val_loss: 3.694 train_f1: 0.99 val_fl: 0.962
epoch: 55
train_loss: 0.708 val_loss: 3.836 train_f1: 0.99 val_fl: 0.970
epoch: 60
train_loss: 0.704 val_loss: 4.045 train_f1: 0.99 val_fl: 0.971
epoch: 65
train_loss: 0.672 val_loss: 4.012 train_f1: 0.99 val_fl: 0.970
epoch: 70
train_loss: 0.620 val_loss: 4.162 train_f1