In [1]:
#load model
#load data
# create new dataset which has more information (country, number of genes etc)
# run a 1 epoch no batching training
# extract the CLS 
#PCA --> Plotting 

import numpy as np
import pandas as pd
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchtext
import torchtext.vocab as vocab
from pathlib import Path
import os
import re
from copy import deepcopy


from data_preprocessing import data_loader, data_original
from build_vocabulary import vocab_geno
from build_vocabulary import vocab_pheno
from bert_builder import BERT
from misc import get_paths
from misc import model_loader

from sklearn.decomposition import PCA
############################
model_name = '2024-03-14modelEnc3Emb256Mask0.15ModeTrue.pt'
hyperparameters = re.findall('[A-Z][^A-Z]*', model_name)
numbers = []
for i in range (len(hyperparameters)):
    numbers.append(re.findall('\d+', hyperparameters[i]))

parameters = [int(num) for sublist in numbers for num in sublist if num]


num_enc = parameters[0]
dim_emb = parameters[1]
dim_hidden = parameters[1]
mask_prob = 0

########################
threshold_year = 1970
max_length = [51,44]
mask_prob = 0.15
drop_prob = 0.2
reduced_samples = 1000 

attention_heads = 4 

include_pheno = True   
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#############################

base_dir = Path(os.path.abspath(''))
os.chdir(base_dir)
data_dir, ab_dir, save_directory = get_paths()

print(f"\n Retrieving data from: {data_dir}")
print("Loading data...")
NCBI,ab_df = data_loader(include_pheno,threshold_year,data_dir,ab_dir)
NCBI_geno_only = data_original(threshold_year,data_dir, ab_dir)
print(f"Data correctly loaded, {len(NCBI)} samples found")
print("Creating vocabulary...")
vocabulary_geno = vocab_geno(NCBI_geno_only)
vocabulary_pheno = vocab_pheno(ab_df)

class CLSDataset(Dataset):

    MASKED_INDICES_COLUMN = 'masked_indices'
    TARGET_COLUMN = 'indices'
    TOKEN_MASK_COLUMN = 'token_mask'
    AB_INDEX = 'ab_index'
    SR_CLASS = 'sr_class'
    NUM_GENES = 'num_genes'
    NUM_ABS = 'num_abs'
    LOCATIONS = 'locations'

    def __init__(self,
                 data: pd.DataFrame,
                 vocab_geno: vocab,
                 vocab_pheno: vocab,
                 max_seq_len: list,
                 mask_prob: float,
                 include_pheno:bool,
                 random_state: int = 23,
                 ):
        
        self.random_state = random_state
        np.random.seed(self.random_state)

        CLS = '[CLS]'
        PAD = '[PAD]'
        MASK = '[MASK]'
        UNK = '[UNK]'

        self.include_pheno = include_pheno
        self.data = data.reset_index(drop=True) 
        self.num_samples = self.data.shape[0]
        self.vocab_geno = vocab_geno
        self.vocab_pheno = vocab_pheno
        self.vocab_size_geno = len(self.vocab_geno)
        self.CLS = CLS 
        self.PAD = PAD
        self.MASK = MASK
        self.UNK = UNK
        self.max_seq_len = max_seq_len
        self.mask_prob = mask_prob

        self.columns = [self.MASKED_INDICES_COLUMN, self.TARGET_COLUMN, self.AB_INDEX, self.SR_CLASS, self.NUM_GENES, self.NUM_ABS, self.LOCATIONS]


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        input = torch.tensor(item[self.MASKED_INDICES_COLUMN],device=device).long()
        token_mask  = torch.tensor(item[self.TARGET_COLUMN], device=device).long()
        attention_mask = (input == self.vocab_geno[self.PAD]).unsqueeze(0)
        ab_idx  = torch.tensor(item[self.AB_INDEX], device=device).long()
        sr_class = torch.tensor(item[self.SR_CLASS], device=device).long()
        num_genes = torch.tensor(item[self.NUM_GENES], device=device).long()
        num_abs = torch.tensor(item[self.NUM_ABS], device=device).long()
        locations = torch.tensor(item[self.LOCATIONS], device=device).long()

        return input, token_mask , attention_mask, ab_idx, sr_class, num_genes, num_abs, locations


    def _construct_masking(self):
        sequences = deepcopy(self.data['genes'].tolist())
        masked_sequences = []
        target_indices_list = []
        seq_starts = [[self.CLS, self.data['year'].iloc[i], self.data['location'].iloc[i]] for i in range(self.data.shape[0])]

        for i, geno_seq in enumerate(sequences):
            seq_len = len(geno_seq)
            masking_index = np.random.rand(seq_len) < self.mask_prob   
            target_indices = np.array([-1]*seq_len)
            indices = masking_index.nonzero()[0]
            target_indices[indices] = self.vocab_geno.lookup_indices([geno_seq[i] for i in indices])
            for i in indices:
                r = np.random.rand()
                if r < 0.8:
                    geno_seq[i] = self.MASK
                elif r > 0.9:
                    geno_seq[i] = self.vocab_geno.lookup_token(np.random.randint(self.vocab_size_geno))
            geno_seq = seq_starts[i] + geno_seq
            target_indices = [-1]*3 + target_indices.tolist() 
            masked_sequences.append(geno_seq)
            target_indices_list.append(target_indices)
        masked_sequences = [seq + [self.PAD]*(self.max_seq_len[0] - len(seq)) for seq in masked_sequences]
        for i in range(len(target_indices_list)):
            indices = target_indices_list[i]
            padding = [-1] * (self.max_seq_len[0] - len(indices))
            target_indices_list[i] = indices + padding
        return masked_sequences, target_indices_list 
    
    def _Ab_SR_indexing(self):
        sequences = deepcopy(self.data['AST_phenotypes'].tolist())
        list_idx = []
        list_SR = []
        for i in range(len(sequences)):
            current_seq = sequences[i]
            current_idxs = []
            current_SRs = []
            for j in range(len(current_seq)):
                item = current_seq[j].split('=')
                abs = item[0]   
                sr = item[1]
                current_idxs.append(self.vocab_pheno.lookup_indices([abs]))
                for k in range(len(sr)):
                    if sr == 'R':
                        current_SRs.append(1)
                    else:
                        current_SRs.append(0)
            current_idxs = [int(item[0]) for item in current_idxs]
            for i in range(0,self.max_seq_len[1] - len(current_idxs)):
                current_idxs.append(-1)
            for i in range(0,self.max_seq_len[1] - len(current_SRs)):
                current_SRs.append(-1)
            list_idx.append(current_idxs)
            list_SR.append(current_SRs)
        return list_idx, list_SR
    
    def _num_tested(self):
        ab_sequences = deepcopy(self.data['AST_phenotypes'].tolist())
        gene_sequences = deepcopy(self.data['genes'].tolist())
        num_genes = []
        num_abs = []
        for i in range(len(gene_sequences)):
            current_gene_seq = gene_sequences[i]
            current_ab_seq = ab_sequences[i]

            num_genes.append(len(current_gene_seq))
            num_abs.append(len(current_ab_seq))
        
        return num_genes, num_abs
    
    def _location(self):
        location = deepcopy(self.data['location'].tolist())
        locations = []
        for i in range(len(location)):
            current_location = location[i]

            locations.append(self.vocab_geno.lookup_indices([current_location]))
        
        return locations
    
    
    def prepare_dataset(self):
        masked_sequences, target_indices = self._construct_masking()
        indices_masked = [self.vocab_geno.lookup_indices(masked_seq) for masked_seq in masked_sequences]
        list_idx, list_SR = self._Ab_SR_indexing()
        num_genes, num_abs = self._num_tested()
        locations = self._location()    

        rows = zip(indices_masked, target_indices, list_idx, list_SR,num_genes,num_abs, locations)
        self.df = pd.DataFrame(rows, columns=self.columns)

dataset = CLSDataset(NCBI, vocabulary_geno, vocabulary_pheno, max_length, mask_prob, include_pheno)
dataset.prepare_dataset()
print("Dataset correctly prepared.")

class getCLS:
    def __init__(self, model, dataset, epochs, batch_size,device):
        
        random_seed = 42
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed(random_seed)
        torch.backends.cudnn.deterministic = True

        self.model = model

        self.dataset = dataset
        self.epochs = epochs    
        self.batch_size = batch_size
        self.current_epoch  = 0


        self.device = device
    def __call__(self):    
        self._init_result_lists()
        for self.current_epoch in range(self.current_epoch, self.epochs):
            self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
            self.CLS_out = self.train(self.current_epoch)

            results = {
            "CLS_tokens": self.CLS_out, 
        }
        return results

    def _init_result_lists(self):
        self.CLS_out = []

    def train(self, epoch: int):
        cls_tokens = []
        for i, batch in enumerate(self.data_loader):
            input, token_target, attn_mask, AB_idx, SR_class, filler1, filler2, filler3 = batch

            token_predictions, resistance_predictions, cls_s = self.model(input, attn_mask) 
            print(cls_s.shape)
            cls_tokens.append(cls_s)
            print(len(cls_tokens))

        cls_tokens = torch.cat(cls_tokens, dim=0)
        
        return cls_tokens

model = BERT(vocab_size=len(vocabulary_geno), dim_embedding = dim_emb, dim_hidden=dim_hidden, attention_heads=8, num_encoders=num_enc, dropout_prob=drop_prob, num_ab=len(vocabulary_pheno), device=device).to(device)
trainer = getCLS(model, dataset, 1, 8, device)

results = trainer()

c:\Users\erika\Desktop\Exjobb\repo\base

 Retrieving data from: c:\Users\erika\Desktop\Exjobb\data
Loading data...
Data correctly loaded, 6485 samples found
Creating vocabulary...
Dataset correctly prepared.
torch.Size([8, 256])
1
torch.Size([8, 256])
2
torch.Size([8, 256])
3
torch.Size([8, 256])
4
torch.Size([8, 256])
5
torch.Size([8, 256])
6
torch.Size([8, 256])
7
torch.Size([8, 256])
8
torch.Size([8, 256])
9
torch.Size([8, 256])
10
torch.Size([8, 256])
11
torch.Size([8, 256])
12
torch.Size([8, 256])
13
torch.Size([8, 256])
14
torch.Size([8, 256])
15
torch.Size([8, 256])
16
torch.Size([8, 256])
17
torch.Size([8, 256])
18
torch.Size([8, 256])
19
torch.Size([8, 256])
20
torch.Size([8, 256])
21
torch.Size([8, 256])
22
torch.Size([8, 256])
23
torch.Size([8, 256])
24
torch.Size([8, 256])
25
torch.Size([8, 256])
26
torch.Size([8, 256])
27
torch.Size([8, 256])
28
torch.Size([8, 256])
29
torch.Size([8, 256])
30
torch.Size([8, 256])
31
torch.Size([8, 256])
32
torch.Size([8, 256])
33
torch.Siz

In [2]:
from torch.utils.data import DataLoader 
import torch.nn as nn
import copy
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

print(results.shape)

array_2d = cls_tokens.detach().numpy()

# Create DataFrame from NumPy array
df = pd.DataFrame(array_2d)

principalComponents = pca.fit_transform(df)

principalDf = pd.DataFrame(data = principalComponents
             , columns = ['principal component 1', 'principal component 2'])

principalDf.head()


plt.figure(figsize=(8, 6))
plt.scatter(principalDf['principal component 1'], principalDf['principal component 2'])
plt.title('2D Scatter Plot of Principal Components')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.grid(True)
plt.show()



AttributeError: 'dict' object has no attribute 'shape'