In [None]:
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

%matplotlib inline

In [3]:
%load_ext rpy2.ipython

In [None]:
%%R
if (!require("BiocManager", quietly = TRUE))
    install.packages("BiocManager")
BiocManager::install(version = "3.16")

In [None]:
%%R
BiocManager::install("splatter")

In [7]:
%%R
# Load package
suppressPackageStartupMessages({
  library(splatter)
  library(scater)
})

## 1)Pre-processing
### 1.1) Simulate 5000 samples with 2000 genes from six cell types, with the ratio of 5, 10, 10, 20, 25, 30 (fairly imbalanced), and load the samples in Torch tensors.
Please use values for other parameters in a way that simplifies the data.

In [11]:
%%R
sim <- splatSimulate(method="groups", nGenes = 2000, batchCells = 5000 , mean.rate = 0.6,mean.rate = 0.6, dropout.mid = c(6,3,6,3,6,3), dropout.shape=c(-1,-1,-1,-1,-1,-1), dropout.type="group", group.prob = c(0.05, 0.1, 0.1, 0.2, 0.25, 0.3), verbose=FALSE)

In [12]:
%%R
write.table(as.array(counts(sim)), sep='\t', 'data.tsv')
write.table(colData(sim), sep='\t', 'labels.tsv')

In [34]:
import numpy as np
import pandas as pd
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

data = pd.read_csv('data.tsv', sep='\t').transpose().to_numpy()
labels = pd.read_csv('labels.tsv', sep='\t')["Group"].to_numpy()

In [35]:
data = preprocessing.normalize(data)
labels = preprocessing.LabelEncoder().fit_transform(labels)
X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

X_train = torch.from_numpy(X_train).to(torch.float32)
y_train = torch.from_numpy(y_train).to(torch.long)
X_test = torch.from_numpy(X_test).to(torch.float32)
y_test = torch.from_numpy(y_test).to(torch.long)

## 2) Use the model in exercise 4 on the data to classify RNA samples into one of the six cell types. Split the data to train and test sets with a proportionate number of samples in the train and test set. 
### 2.1) Train the model using SGD, Cross entropy loss

Please remember to load the data loader, instantiate the model, optimizer, and loss, and implement the training loop. Please use enough number epochs and proper batch-size, and learning rate to improve the model convergence.
Please calculate test and train loss and accuracy. Also, calculate the test AUC.

In [76]:
# hyperparameters
learning_rate = 0.1
batch_size = 32
epochs = 50

In [77]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [78]:
train_dataset = torch.utils.data.TensorDataset(X_train, y_train) 
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = torch.utils.data.TensorDataset(X_test, y_test) 
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [79]:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(2000, 500),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(500, 6),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = x.view(-1, 2000)
        output = self.layers(x)
        return output

In [80]:
def train(model, log_interval=50):
    # Set model to training mode
    model.train()
    size = len(train_loader.dataset)
    # Loop over each batch from the training set
    for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(f'Train Loss: {loss.item()}')

In [81]:
def test(model):
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            correct += (output.argmax(1) == target).type(torch.float).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'Test Loss: {test_loss}')

In [82]:
import torch
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model)
    test(model)
print("Done!")

Epoch 1
-------------------------------
Train Loss: 1.7928946018218994
Train Loss: 1.7802987098693848
Train Loss: 1.774113416671753
Test Loss: 0.056557815194129946
Epoch 2
-------------------------------
Train Loss: 1.780104398727417
Train Loss: 1.75491201877594
Train Loss: 1.7663451433181763
Test Loss: 0.055753186106681826
Epoch 3
-------------------------------
Train Loss: 1.7416731119155884
Train Loss: 1.6988426446914673
Train Loss: 1.7150121927261353
Test Loss: 0.0551405885219574
Epoch 4
-------------------------------
Train Loss: 1.7629846334457397
Train Loss: 1.6390866041183472
Train Loss: 1.6644062995910645
Test Loss: 0.054883543133735654
Epoch 5
-------------------------------
Train Loss: 1.7128229141235352
Train Loss: 1.6971194744110107
Train Loss: 1.7092393636703491
Test Loss: 0.054717058181762694
Epoch 6
-------------------------------
Train Loss: 1.618843913078308
Train Loss: 1.7510368824005127
Train Loss: 1.7558908462524414
Test Loss: 0.05457686364650726
Epoch 7
----------

### 2.2) Try to run the same model with Focal Loss to tackle the data imbalancedness. 
Please compare the results with the 2.1 results to show the impact of using focal loss on tackling data imbalancedness.

In [None]:
from torchvision.ops import sigmoid_focal_loss

model2 = Net().to(device)
ce_loss = nn.CrossEntropyLoss()
criterion = sigmoid_focal_loss

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model2)
    test(model2)
print("Done!")

## 3) Tackle over-fitting
### 3.1) Use [Batch normalization layers](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html) to improve the results. Compare the results with previous ones.

For comparison, please draw AUC curve changes during different epochs (three curves for 2.1, 2.2, and 3.1)

## 4) Improve the results **(Bonus)**
### 4.1) Use the internet resources to further improve the results by changing layers, architecture, optimizer, etc.