# scBERT
This is a modified version of the code provided by the scBERT authors.

## Imports

In [None]:
import os
import gc
import argparse
import json
import random
import math
import random
from functools import reduce
import numpy as np
import pandas as pd
from scipy import sparse
from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support, classification_report
import torch
from torch import nn
from torch.optim import Adam, SGD, AdamW
from torch.nn import functional as F
from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts, CyclicLR
from torch.utils.data import DataLoader, Dataset

from performer_pytorch import PerformerLM
import scanpy as sc
import anndata as ad
from utils import *
import pickle as pkl
import time 

## Hyperparameter Setup

In [None]:
# --- Run configuration -------------------------------------------------------
SEED = 2021                   # random_state of the StratifiedShuffleSplit below
EPOCHS = 15                   # maximum number of epochs per split
BATCH_SIZE = 1                # cells per DataLoader batch
GRADIENT_ACCUMULATION = 60    # the optimizer steps once every this many batches
LEARNING_RATE = 1e-4          # optimizer lr and the scheduler's max_lr
SEQ_LEN = 3001                # PerformerLM max_seq_len; SCDataset appends one extra
                              # token per cell, so presumably 3000 genes + 1 -- confirm
VALIDATE_EVERY = 1            # run validation every N epochs
TEST_EVERY = 1                # run the test pass every N epochs
LOSS = 'ce'                   # 'ce' | 'focal' | 'ldam'
OPTIMIZER = 'adam'            # 'adam' | 'adamw' | 'sgd'

PATIENCE = 10                 # stop once more than this many evaluations fail to improve accuracy
UNASSIGN_THRES = 0.0          # predictions whose top softmax probability is below this are dropped

CLASS = 7                     # PerformerLM num_tokens; SCDataset clips expression values to CLASS - 2
POS_EMBED_USING = True        # forwarded to PerformerLM as g2v_position_emb

model_name = 'ms_default'
modelraw = model_name         # base name; the split number is appended per split
ckpt_dir = './results/'
data_path = 'ms_default.h5ad'
test_path = 'ms_test.h5ad'

# Fall back to the CPU when no GPU is visible so that `.to(device)` calls do not
# fail outright on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Class Definitions

In [None]:
class SCDataset(Dataset):
    """Dataset that serves one randomly drawn cell per request.

    ``data`` must support ``data[row].toarray()`` (e.g. a SciPy sparse matrix)
    and ``label`` must be indexable by row position.

    NOTE: ``__getitem__`` ignores the requested ``index`` and draws a row
    uniformly at random (with replacement) on every call, so one pass over a
    DataLoader does not visit each row exactly once.
    """

    def __init__(self, data, label):
        super().__init__()
        self.data = data
        self.label = label

    def __len__(self):
        """Return the number of rows in ``data``."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(tokens, label)`` for a random row; ``index`` is unused."""
        row = random.randint(0, len(self) - 1)
        counts = self.data[row].toarray()[0]
        # Cap every value at CLASS - 2, in place.
        np.minimum(counts, CLASS - 2, out=counts)
        tokens = torch.from_numpy(counts).long()
        # Append a single trailing 0 token, then move the sequence to `device`.
        tokens = torch.cat((tokens, torch.tensor([0]))).to(device)
        return tokens, self.label[row]


class Identity(torch.nn.Module):
    """Classification head: a 1 x emb_dim convolution followed by a 3-layer MLP.

    The convolution collapses the last axis of the input to a single value per
    sequence position; the flattened result goes through
    ``fc1 -> ReLU -> dropout -> fc2 -> ReLU -> dropout -> fc3``.

    Args:
        dropout: Dropout probability applied after ``fc1`` and ``fc2``.
        h_dim: Width of the second hidden layer.
        out_dim: Number of output logits.
        seq_len: Input size of ``fc1``. Defaults to the module-level
            ``SEQ_LEN`` when ``None`` (the previous hard-coded behaviour).
        emb_dim: Width of the convolution kernel (previously fixed at 200).
            The flattened conv output only matches ``fc1`` when the last input
            axis equals ``emb_dim``.

    Layer attribute names are unchanged, so existing state dicts still load.
    """

    def __init__(self, dropout = 0., h_dim = 100, out_dim = 10, seq_len = None, emb_dim = 200):
        super().__init__()
        if seq_len is None:
            seq_len = SEQ_LEN
        self.conv1 = nn.Conv2d(1, 1, (1, emb_dim))
        self.act = nn.ReLU()
        self.fc1 = nn.Linear(in_features=seq_len, out_features=512, bias=True)
        self.act1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(in_features=512, out_features=h_dim, bias=True)
        self.act2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(in_features=h_dim, out_features=out_dim, bias=True)

    def forward(self, x):
        """Map a 3-D input to ``(batch, out_dim)`` logits.

        The indexing below requires a 3-D tensor; it is assumed to be
        ``(batch, seq_len, emb_dim)`` -- TODO confirm against the encoder output.
        """
        x = x[:, None, :, :]            # add the single conv channel
        x = self.act(self.conv1(x))
        x = x.view(x.shape[0], -1)      # flatten to one value per position
        x = self.dropout1(self.act1(self.fc1(x)))
        x = self.dropout2(self.act2(self.fc2(x)))
        return self.fc3(x)

    
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    For each sample, only the target-class entry of
    ``-alpha * (1 - p) ** gamma * log(p)`` is non-zero, where ``p`` is the
    softmax probability clamped to ``[1e-7, 1]``.

    Args:
        alpha: Scalar multiplier applied to the loss.
        gamma: Focusing exponent applied to ``1 - p``.
        num_classes: Number of classes (size of ``inputs`` along dim 1).
        reduction: ``'mean'`` averages over all ``batch * num_classes``
            entries, ``'sum'`` adds them up, ``'none'`` returns the
            ``(batch, num_classes)`` matrix.

    Raises:
        ValueError: If ``reduction`` is not one of the values above. The
            earlier version silently returned ``None`` from ``forward``.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=1, gamma=2, num_classes = 7, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (batch, num_classes) vs ``targets`` (batch,)."""
        # Build the one-hot mask on the same device as the logits so the loss
        # is computed there; no host round trip is needed.
        one_hot = F.one_hot(
            targets.to(inputs.device).long(), self.num_classes
        ).to(inputs.dtype)

        # Cross entropy per entry; the clamp keeps log() finite.
        prob = F.softmax(inputs, dim=1).clamp(1e-7, 1.0)
        ce = -(one_hot * torch.log(prob))

        # Down-weight well-classified examples.
        fl = ce * torch.pow(1. - prob, self.gamma) * self.alpha

        if self.reduction == 'sum':
            return fl.sum()
        if self.reduction == 'mean':
            return fl.mean()
        return fl

class LDAMLoss(nn.Module):
    """Label-distribution-aware margin loss.

    Each class ``j`` gets a margin proportional to ``n_j ** -0.25``, rescaled so
    that the largest margin equals ``max_m``. In ``forward`` the margin of the
    target class is subtracted from that class's logit, the logits are scaled
    by ``s``, and the result is passed to cross entropy.

    Args:
        cls_num_list: Per-class sample counts, ordered by class index.
        max_m: Margin assigned to the rarest class.
        weight: Optional per-class weight forwarded to ``F.cross_entropy``.
        s: Positive scale applied to the logits before cross entropy.

    Raises:
        ValueError: If ``s`` is not strictly positive.
    """

    def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
        super().__init__()
        if s <= 0:
            raise ValueError(f"s must be positive, got {s}")
        counts = np.asarray(cls_num_list, dtype=np.float64)
        margins = 1.0 / np.sqrt(np.sqrt(counts))
        margins = margins * (max_m / np.max(margins))
        # A buffer follows the module through .to()/.cuda() and, unlike
        # torch.cuda.FloatTensor, can be created without a GPU.
        self.register_buffer('m_list', torch.tensor(margins, dtype=torch.float32))
        self.s = s
        self.weight = weight

    def forward(self, x, target):
        """Return the LDAM loss of logits ``x`` (batch, classes) vs ``target`` (batch,)."""
        # Boolean mask of each row's target column (torch.where needs bool).
        class_ids = torch.arange(x.size(1), device=x.device)
        target_mask = class_ids[None, :] == target.view(-1, 1)

        # Margin of each sample's target class, as a column vector.
        batch_m = self.m_list.to(device=x.device, dtype=x.dtype)[target].view(-1, 1)

        output = torch.where(target_mask, x - batch_m, x)
        return F.cross_entropy(self.s * output, target, weight=self.weight)

## Data Setup, Training, and Inference

In [None]:
# Load the train/validation pool and the held-out test set.
dataraw = sc.read_h5ad(data_path)
test_data = sc.read_h5ad(test_path)

# label_dict: sorted unique cell-type names; label: integer code of every cell.
label_dict, label = np.unique(np.array(dataraw.obs['celltype']), return_inverse=True)
with open('label_dict', 'wb') as fp:
    pkl.dump(label_dict, fp)
with open('label', 'wb') as fp:
    pkl.dump(label, fp)

# Per-class cell counts and the derived class weights (1 - frequency) ** 2.
class_num = np.unique(label, return_counts=True)[1].tolist()
total_cells = sum(class_num)
class_weight = torch.tensor([(1 - (x / total_cells)) ** 2 for x in class_num])
label = torch.from_numpy(label)
dataraw = dataraw.X

acc = []
f1 = []
f1w = []
pred_list = pd.Series(['un'] * dataraw.shape[0])

# Encode the test labels with the training vocabulary. Passing the raw
# cell-type strings to SCDataset yields string batches that cannot be turned
# into the integer targets the loss functions need.
label_to_index = {name: idx for idx, name in enumerate(label_dict)}
test_celltypes = np.array(test_data.obs['celltype'])
unseen_celltypes = sorted({str(name) for name in test_celltypes if name not in label_to_index})
if unseen_celltypes:
    raise ValueError(
        f'Test set contains cell types absent from the training data: {unseen_celltypes}'
    )
test_label = torch.tensor([label_to_index[name] for name in test_celltypes], dtype=torch.long)
test_dataset = SCDataset(test_data.X, test_label)

In [None]:
def _append_csv_row(csv_path, row):
    """Append ``row`` (column name -> scalar) to the CSV at ``csv_path``, creating it if absent."""
    frame = pd.DataFrame({column: [value] for column, value in row.items()})
    if os.path.exists(csv_path):
        frame = pd.concat([pd.read_csv(csv_path), frame])
    frame.to_csv(csv_path, index=False)


def _evaluate(model, loader, loss_fn, label_to_index):
    """Run ``model`` over ``loader`` without gradients.

    Returns ``(mean_loss, truths, predictions)``; the last two are NumPy arrays
    from which predictions whose top softmax probability is below
    ``UNASSIGN_THRES`` have been removed.
    """
    model.eval()
    running_loss = 0.0
    n_batches = 0
    predictions, truths = [], []
    with torch.no_grad():
        for batch_data, batch_labels in loader:
            n_batches += 1
            if not torch.is_tensor(batch_labels):
                # Labels that arrive as cell-type names are mapped to class indices.
                batch_labels = torch.tensor([label_to_index[name] for name in batch_labels])
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
            logits = model(batch_data)
            running_loss += loss_fn(logits, batch_labels).item()
            final_prob = F.softmax(logits, dim=-1)
            final = final_prob.argmax(dim=-1)
            final[final_prob.max(dim=-1)[0] < UNASSIGN_THRES] = -1
            predictions.append(final)
            truths.append(batch_labels)
    predictions = torch.cat(predictions, dim=0)
    truths = torch.cat(truths, dim=0)
    assigned = predictions != -1
    return (
        running_loss / n_batches,
        truths[assigned].cpu().numpy(),
        predictions[assigned].cpu().numpy(),
    )


def _save_eval_reports(saving_dir, loss_column, loss_value, truths, predictions, epoch):
    """Write the loss row, confusion matrix and classification report for one evaluation, and print them."""
    conf_mat = confusion_matrix(truths, predictions)
    report = classification_report(truths, predictions, digits=4)
    report_dict = classification_report(truths, predictions, digits=4, output_dict=True)
    os.makedirs(saving_dir, exist_ok=True)
    _append_csv_row(os.path.join(saving_dir, f'{loss_column}.csv'), {loss_column: loss_value})
    pd.DataFrame(conf_mat).to_csv(os.path.join(saving_dir, f'conf_mat_E{epoch}.csv'), index=False)
    pd.DataFrame(report_dict).transpose().to_csv(os.path.join(saving_dir, f'report_E{epoch}.csv'), index=False)
    print(conf_mat)
    print(report)


# The script this notebook derives from read these two paths from an argparse
# namespace named `args`, which is never built here. Honour `args` if an earlier
# cell defined it; otherwise fall back to local defaults.
try:
    pretrained_path, res_dir = args.model_path, args.res_dir
except NameError:
    # NOTE(review): the checkpoint location is an assumption -- point this at
    # the pretrained scBERT weights before running.
    pretrained_path, res_dir = './panglao_pretrained.pth', ckpt_dir

# Cell-type name -> class index, in the order of label_dict.
label_to_index = {name: idx for idx, name in enumerate(label_dict)}

number_split = 0
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=SEED)
for index_train, index_val in sss.split(dataraw, label):
    number_split = number_split + 1
    print(f'Start of split: {number_split}')

    model_name = modelraw + str(number_split)
    data_train, label_train = dataraw[index_train], label[index_train]
    data_val, label_val = dataraw[index_val], label[index_val]
    train_dataset = SCDataset(data_train, label_train)
    val_dataset = SCDataset(data_val, label_val)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    model = PerformerLM(
        num_tokens = CLASS,
        dim = 200,
        depth = 6,
        max_seq_len = SEQ_LEN,
        heads = 10,
        local_attn_heads = 0,
        g2v_position_emb = POS_EMBED_USING
    )

    ckpt = torch.load(pretrained_path)
    model.load_state_dict(ckpt['model_state_dict'])
    # Freeze everything, then unfreeze the final norm and the second-to-last
    # performer layer; the freshly created head below is trainable by default.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.norm.parameters():
        param.requires_grad = True
    for param in model.performer.net.layers[-2].parameters():
        param.requires_grad = True
    model.to_out = Identity(dropout=0., h_dim=128, out_dim=label_dict.shape[0])
    model = model.to(device)

    # optimizer
    if OPTIMIZER == 'adam':
        optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
    elif OPTIMIZER == 'adamw':
        optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    elif OPTIMIZER == 'sgd':
        optimizer = SGD(model.parameters(), lr=LEARNING_RATE)
    else:
        raise ValueError(f'Optimizer \'{OPTIMIZER}\' is not a valid option.')

    scheduler = CosineAnnealingWarmupRestarts(
        optimizer,
        first_cycle_steps=15,
        cycle_mult=2,
        max_lr=LEARNING_RATE,
        min_lr=1e-6,
        warmup_steps=5,
        gamma=0.9
    )

    if LOSS == 'ce':
        loss_fn = nn.CrossEntropyLoss(weight=None).to(device)
    elif LOSS == 'focal':
        loss_fn = FocalLoss(alpha=1, gamma=2, num_classes=len(label_dict)).to(device)
    elif LOSS == 'ldam':
        # LDAM expects per-class sample counts, not the per-cell label vector.
        loss_fn = LDAMLoss(class_num, max_m=0.5, s=30).to(device)
    else:
        raise ValueError(f'Loss function \'{LOSS}\' is not a valid option.')

    # NOTE(review): validation and test share this early-stopping state, so
    # test accuracy influences both checkpoint selection and stopping. Kept
    # as-is to preserve behaviour; consider tracking the two separately.
    trigger_times = 0
    max_acc = 0.0
    for i in range(1, EPOCHS+1):

        tik = time.time()

        # ---- training -------------------------------------------------------
        # Single-process run: there is no distributed sampler to re-seed and no
        # DistributedDataParallel wrapper, so every batch simply accumulates
        # gradients and the optimizer steps every GRADIENT_ACCUMULATION batches.
        model.train()
        running_loss = 0.0
        cum_acc = 0.0
        for index, (data, labels) in enumerate(train_loader, start=1):
            data, labels = data.to(device), labels.to(device)
            logits = model(data)
            loss = loss_fn(logits, labels)
            loss.backward()
            if index % GRADIENT_ACCUMULATION == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e6))
                optimizer.step()
                optimizer.zero_grad()
            running_loss += loss.item()
            final = F.softmax(logits, dim=-1).argmax(dim=-1)
            pred_num = labels.size(0)
            correct_num = torch.eq(final, labels).sum(dim=-1)
            cum_acc += torch.true_divide(correct_num, pred_num).mean().item()
        epoch_loss = running_loss / index
        epoch_acc = 100 * cum_acc / index
        tok = time.time()
        train_time = tok - tik
        max_memory_reserved = torch.cuda.max_memory_reserved() / 1e9

        saving_dir = os.path.join(res_dir, model_name, 'train', f'split_{number_split}')
        os.makedirs(saving_dir, exist_ok=True)
        _append_csv_row(
            os.path.join(saving_dir, 'resources_usage.csv'),
            {
                'time_per_epoch': train_time,
                'memory_reserved': max_memory_reserved,
                'train_loss': epoch_loss,
            },
        )

        print(f'    ==  Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}%  ==')
        print(f'Time: {train_time}')
        scheduler.step()

        # ---- validation -----------------------------------------------------
        if i % VALIDATE_EVERY == 0:
            val_loss, truths, predictions = _evaluate(model, val_loader, loss_fn, label_to_index)
            cur_acc = accuracy_score(truths, predictions)
            # Separate name so the module-level `f1` list is not overwritten.
            f1_macro = f1_score(truths, predictions, average='macro')
            print(f'    ==  Epoch: {i} | Validation Loss: {val_loss:.6f} | F1 Score: {f1_macro:.6f}  ==')
            _save_eval_reports(
                os.path.join(res_dir, model_name, 'val', f'split_{number_split}'),
                'val_loss', val_loss, truths, predictions, i,
            )
            print(label_to_index)
            if cur_acc > max_acc:
                max_acc = cur_acc
                trigger_times = 0
                save_best_ckpt(i, number_split, model, optimizer, scheduler, val_loss, model_name, os.path.join(ckpt_dir, 'val'), predictions)
            else:
                trigger_times += 1
                if trigger_times > PATIENCE:
                    break

        # ---- test -----------------------------------------------------------
        if i % TEST_EVERY == 0:
            test_loss, truths, predictions = _evaluate(model, test_loader, loss_fn, label_to_index)
            cur_acc = accuracy_score(truths, predictions)
            f1_macro = f1_score(truths, predictions, average='macro')
            print(f'    ==  Epoch: {i} | Test Loss: {test_loss:.6f} | F1 Score: {f1_macro:.6f}  ==')
            _save_eval_reports(
                os.path.join(res_dir, model_name, 'test', f'split_{number_split}'),
                'test_loss', test_loss, truths, predictions, i,
            )
            if cur_acc > max_acc:
                max_acc = cur_acc
                trigger_times = 0
                save_best_ckpt(i, number_split, model, optimizer, scheduler, test_loss, model_name, os.path.join(ckpt_dir, 'test'), predictions)
            else:
                trigger_times += 1
                if trigger_times > PATIENCE:
                    break

        # Release cached GPU memory before the next epoch.
        torch.cuda.empty_cache()

