In [1]:
import os
import pickle
import re
from collections import defaultdict, Counter

import matplotlib
import numpy as np
import pandas as pd
import torch
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from torch.optim import SGD, Adam, AdamW, lr_scheduler
from torch.utils.data import DataLoader

from models.Text1DCNN_model import TextCNN1d
from utils import Logger, AverageMeter, save_checkpoint
from build_dataset import build_dataset
from training import train_epoch
from validation import validate_model


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'build_dataset'

Hyperparameters

In [12]:
# Input dataset and output locations (relative to the notebook's working directory).
data_path = "datasets/HLA_CDR/merge_dataset_3k.pickle"
result_path = "log/save_models"
train_logger_path = "log/train_logger.log"
train_batch_logger_path = "log/train_batch_logger.log"
validation_logger_path = "log/validation_logger_path.log"

# Optimisation settings.
learning_rate, batch_size, epoches = 0.1, 16, 100

# Column layouts handed to each Logger.
epoch_log_columns = ['epoch', 'loss', 'HammingLoss', 'lr']
batch_log_columns = ['epoch', 'batch', 'iter', 'loss', 'HammingLoss', 'lr']
validation_log_columns = ['epoch', 'loss', 'HammingLoss']

train_logger = Logger(train_logger_path, epoch_log_columns)
# NOTE: this rebinds `train_batch_logger_path` from the path string to the Logger
# built from it; the training loop passes this same name as `batch_logger`.
train_batch_logger_path = Logger(train_batch_logger_path, batch_log_columns)
validation_logger = Logger(validation_logger_path, validation_log_columns)

In [6]:
# Load the pickled object stored in the "_TCRpeg" dataset file into `data`.
# NOTE(review): this is a different file from `data_path`
# ("merge_dataset_3k.pickle") — confirm which one is intended.
# pickle.load can execute arbitrary code; only open files from a trusted source.
with open("datasets/HLA_CDR/merge_dataset_3k_TCRpeg.pickle", "rb") as f:
    data = pickle.load(f)

In [None]:
# Read the pickled normalisation statistics and expose the two entries that the
# dataset constructor receives as `mean` / `std`.
# pickle.load can execute arbitrary code; only open files from a trusted source.
with open("datasets/HLA_CDR/means_and_stds_merge_dataset_3k.pickle", "rb") as stats_file:
    means_and_stds = pickle.load(stats_file)
mean, std = means_and_stds["means"], means_and_stds["stds"]

In [4]:
# Build both splits with the same normalisation statistics, then wrap each in a
# DataLoader that shares one set of options.
# NOTE(review): no `shuffle` argument is passed, so DataLoader's default order
# is used for training as well — confirm this is intended.
normalisation = dict(mean=mean, std=std)
loader_options = dict(batch_size=batch_size, pin_memory=True)

train_dataset = build_dataset(data_path, 'training', **normalisation)
validation_dataset = build_dataset(data_path, 'validation', **normalisation)

train_dataloader = DataLoader(train_dataset, **loader_options)
validation_dataloader = DataLoader(validation_dataset, **loader_options)

In [13]:
# Rebuild the training split WITHOUT passing `mean` / `std`.
# NOTE(review): this rebinds `train_dataset`; a DataLoader created earlier from
# the previous object is not affected by the rebinding. Presumably an
# exploratory cell — confirm whether it should remain in the notebook.
train_dataset = build_dataset(data_path, 'training')

In [3]:
datasets = np.load("datasets/sum_datasets.npy", allow_pickle=True).item()

# Rewrite every six-character HLA entry (e.g. "B*5201") with a ":" inserted
# after the fourth character ("B*52:01"); all other entries are copied as-is.
# Each rewritten entry is printed together with its sample name.
for sample in datasets:
    normalised_hlas = []
    for hla in datasets[sample]["HLA"]:
        if len(hla) != 6:
            normalised_hlas.append(hla)
            continue
        print(hla, sample)
        new_gene = hla[:4] + ":" + hla[4:]
        print(new_gene)
        normalised_hlas.append(new_gene)
    datasets[sample]["HLA"] = normalised_hlas

# Persist the updated mapping under a new file name; the input file is left untouched.
np.save("datasets/sum_datasets_cp.npy", datasets)

A*0301 1349BW_unsorted_cc1000000_ImmunRACE_043020_003_gDNA_TCRB
A*03:01
A*6901 1349BW_unsorted_cc1000000_ImmunRACE_043020_003_gDNA_TCRB
A*69:01
B*0702 1349BW_unsorted_cc1000000_ImmunRACE_043020_003_gDNA_TCRB
B*07:02
B*0702 1349BW_unsorted_cc1000000_ImmunRACE_043020_003_gDNA_TCRB
B*07:02
A*0201 1588BW_20200417_PBMC_unsorted_cc1000000_ImmunRACE_050820_008_gDNA_TCRB
A*02:01
A*2301 1588BW_20200417_PBMC_unsorted_cc1000000_ImmunRACE_050820_008_gDNA_TCRB
A*23:01
B*0702 1588BW_20200417_PBMC_unsorted_cc1000000_ImmunRACE_050820_008_gDNA_TCRB
B*07:02
B*1801 1588BW_20200417_PBMC_unsorted_cc1000000_ImmunRACE_050820_008_gDNA_TCRB
B*18:01
A*0203 15_12weeks
A*02:03
A*1101 15_12weeks
A*11:01
B*1501 15_12weeks
B*15:01
B*1801 15_12weeks
B*18:01
A*0203 15_2weeks
A*02:03
A*1101 15_2weeks
A*11:01
B*1501 15_2weeks
B*15:01
B*1801 15_2weeks
B*18:01
A*0203 15_44weeks
A*02:03
A*1101 15_44weeks
A*11:01
B*1501 15_44weeks
B*15:01
B*1801 15_44weeks
B*18:01
A*0203 15_4weeks
A*02:03
A*1101 15_4weeks
A*11:01
B*1501 15_

In [4]:
# Build the network, optimiser and loss used by the training loop.
# `TextCNN1d` is the class imported from `models.Text1DCNN_model`; the module
# name itself is never bound in this notebook, so calling `Text1DCNN_model(...)`
# raises NameError on a fresh kernel.
model = TextCNN1d(input_dim=192)
# NOTE(review): `.to()` is called with no device or dtype argument. If GPU
# training is intended, pass a device here and confirm that `train_epoch` /
# `validate_model` move their batches to the same device.
model.to()
optimizer = Adam(model.parameters(), lr=learning_rate)
# Loss applied to raw logits (no sigmoid applied beforehand in this cell).
criterion = BCEWithLogitsLoss()

Training

In [None]:
# One pass per epoch: train, write a checkpoint on every 10th epoch
# (0, 10, 20, ...), then run validation.
for epoch in np.arange(epoches):
    train_epoch(
        model,
        epoch,
        train_dataloader,
        criterion,
        optimizer,
        # Despite the `_path` suffix, this name is rebound to a Logger where it is defined.
        batch_logger=train_batch_logger_path,
        epoch_logger=train_logger,
    )

    is_checkpoint_epoch = (epoch % 10 == 0)
    if is_checkpoint_epoch:
        checkpoint_name = 'save_{}.pth'.format(epoch)
        save_file_path = os.path.join(result_path, checkpoint_name)
        save_checkpoint(save_file_path, epoch, model, optimizer)

    # The two returned values are kept under these names but are not used inside the loop.
    loss, accuracy = validate_model(model, epoch, validation_dataloader, criterion, logger=validation_logger)

In [None]:
# Run validation once more outside the loop. `epoch` here is whatever value the
# training loop left behind, so this cell raises NameError if that loop has not
# been executed in the current kernel.
validate_model(model, epoch, validation_dataloader, criterion, logger=validation_logger)

In [5]:
# Assemble `datasets2`: sample id -> {"HLA": [normalised alleles], "TCRB": [filtered CDR3s]}
# from the ImmuneAccess metadata table and the per-sample repertoire files.
merge_data = pd.read_csv("datasets/HLA_CDR/new_ImmuneAccess.TRB.metedata.csv")
file_paths = merge_data.file_path.to_list()
datasets2 = defaultdict(dict)

# Character class matching anything that is NOT one of the 20 amino-acid letters listed.
amino_acids = '[^ARNDCQEGHILKMFPSTWYV]'
# Compile once instead of re-parsing the pattern for every sequence.
non_amino_acid_re = re.compile(amino_acids)
# Raw string: "\d" inside a plain string literal is an invalid escape sequence.
digit_run_re = re.compile(r"\d+")

for i, file_path in enumerate(file_paths):
    # Only the basename from the metadata is used; the repertoire file is
    # looked up in the local ImmuneAccess directory and skipped when absent.
    local_path = os.path.join("datasets/ImmuneAccess", file_path.split('/')[-1])
    if not os.path.exists(local_path):
        continue
    content = pd.read_csv(local_path, sep='\t')

    # Clean data. Missing CDR3 entries are dropped first: a NaN inside the
    # `.str.isalpha()` result cannot be used as a boolean mask and would raise.
    cdr3_column = content.cdr3aa.dropna().astype(str)
    cdr3s = cdr3_column[cdr3_column.str.isalpha()].str.upper().tolist()
    # Keep sequences of at most 30 residues that contain only the listed letters.
    filter_cdr3s = [
        tcr for tcr in cdr3s
        if len(tcr) <= 30 and non_amino_acid_re.search(tcr) is None
    ]

    row = merge_data.iloc[i]
    sample_id = row["sample"]
    filter_hla = []
    for hla_type in ["HLA-A1", "HLA-A2", "HLA-B1", "HLA-B2"]:
        hla = row[hla_type]
        if not isinstance(hla, str) or not hla:
            # Missing typing (NaN / empty cell) cannot be indexed or parsed;
            # skip it instead of failing on `hla[0]`.
            continue
        letter = hla[0]  # A/B/C
        s = digit_run_re.findall(hla)
        if len(s) > 1:
            # Several digit groups: join them with ":" after "<letter>*".
            hla = letter + '*' + ":".join(s)
        elif len(s) == 1:
            # A single digit group is kept as-is after "<letter>*".
            hla = letter + '*' + s[0]
        # With no digits at all, the original value is kept unchanged.
        filter_hla.append(hla)
    datasets2[sample_id]["HLA"] = filter_hla
    datasets2[sample_id]["TCRB"] = filter_cdr3s

In [3]:
# Load the pickled object stored in the .npy file; `.item()` unwraps the 0-d
# object array returned by np.load into the Python object it holds.
# allow_pickle=True can execute arbitrary code; only load trusted files.
datasets = np.load("datasets/sum_datasets.npy", allow_pickle=True).item()

In [5]:
# One "TCRB" entry per sample, in the mapping's iteration order.
tcrs = [sample_value["TCRB"] for sample_value in datasets.values()]

In [12]:
# Elements present in every entry of `tcrs`: start from the first entry as a
# set and intersect it with all remaining ones.
# Raises IndexError when `tcrs` is empty.
reference_datasets = set(tcrs[0]).intersection(*tcrs[1:])

In [1]:
from generate_dataset.build_dataset import build_amino_dataset

# Build the "test" split and fetch its first item through normal indexing.
test_split = build_amino_dataset("datasets/sum_datasets.npy", "test")
a = test_split[0]

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# NOTE(review): this lookup is done on the class object itself and raises
# AttributeError (see the recorded output). Presumably `HLA_list` is only set
# on instances — TODO confirm and read it from a constructed dataset instead.
build_amino_dataset.HLA_list

AttributeError: type object 'build_amino_dataset' has no attribute 'HLA_list'

In [6]:
# Display element 1 of the item fetched from the test split. The recorded
# output is a 1-D array of zeros — presumably the label vector; TODO confirm.
a[1]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0.])