In [1]:
import numpy as np
import pandas as pd


from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, recall_score, roc_curve


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm
import matplotlib.pyplot as plt             #visualisation
import seaborn as sns   #visualisation
from torch.utils.tensorboard import SummaryWriter
%matplotlib inline     
sns.set(color_codes=True)

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    raise Exception("Cry about it")


In [3]:
def add_cell_prefix(df, prefix):
    df.index = [prefix + i for i in df.index]
    return df

#Remove extra quotes from the filenames
def extract_cell_name_smartseq(x):
    y = x.split("_")
    return y[len(y)-2]

def get_cell_name_smartseq(file_name):
    return extract_cell_name_smartseq(file_name)

def convert_indexes_to_cell_names_smartseq(df):
    df.index = [get_cell_name_smartseq(x) for x in df.index]
    return df

def get_cell_hypo_or_norm_smartseq(df_meta, cell_name):
    return df_meta[df_meta["Cell name"]==cell_name]["Condition"].values[0]

def seperate_hypo_and_norm_smartseq(df, df_meta):
    df_hypo = df[df.index.map(lambda x: get_cell_hypo_or_norm_smartseq(df_meta, x)=="Hypo")]
    df_norm = df[df.index.map(lambda x: get_cell_hypo_or_norm_smartseq(df_meta, x)=="Norm" or get_cell_hypo_or_norm_smartseq(df_meta, x)=="Normo")]
    return df_hypo, df_norm

def process_df_smartseq(df, df_meta, prefix):
    df = convert_indexes_to_cell_names_smartseq(df)
    _, df_norm = seperate_hypo_and_norm_smartseq(df,df_meta)
    df = add_cell_prefix(df, prefix)
    df_norm = add_cell_prefix(df_norm, prefix)
    return df, df_norm.index

#Remove extra quotes from the filenames
def extract_cell_name_dropseq(x):
    y = x.split("_")
    return y[0]

def get_cell_name_dropseq(file_name):
    return extract_cell_name_dropseq(file_name)

def convert_indexes_to_cell_names_dropseq(df):
    df.index = [get_cell_name_dropseq(x) for x in df.index]
    return df

def get_cell_hypo_or_norm_dropseq(cell_name):
    return cell_name.split("_")[-1]

def seperate_hypo_and_norm_dropseq(df):
    df_hypo = df[df.index.map(lambda x: get_cell_hypo_or_norm_dropseq(x)=="Hypoxia")]
    df_norm = df[df.index.map(lambda x: get_cell_hypo_or_norm_dropseq(x)=="Normoxia")]
    return df_hypo, df_norm

def process_df_dropseq(df, prefix):
    _, df_norm = seperate_hypo_and_norm_dropseq(df)
    df = convert_indexes_to_cell_names_dropseq(df)
    df_norm = convert_indexes_to_cell_names_dropseq(df_norm)
    df = add_cell_prefix(df, prefix)
    df_norm = add_cell_prefix(df_norm, prefix)
    return df, df_norm.index

In [4]:
df_meta = pd.read_csv("Data/SmartSeq/MCF7_SmartS_MetaData.tsv",delimiter="\t", index_col=0)
dffn = pd.read_csv("Data/SmartSeq/MCF7_SmartS_Filtered_Normalised_3000_Data_train.txt",delimiter=" ",index_col=0).T
df2_meta = pd.read_csv("Data/SmartSeq/HCC1806_SmartS_MetaData.tsv",delimiter="\t",index_col=0)
df2fn = pd.read_csv("Data/SmartSeq/HCC1806_SmartS_Filtered_Normalised_3000_Data_train.txt",delimiter=" ",index_col=0).T
df3 = pd.read_csv("Data/DropSeq/MCF7_Filtered_Normalised_3000_Data_train.txt",delimiter=" ",index_col=0).T
df4 = pd.read_csv("Data/DropSeq/HCC1806_Filtered_Normalised_3000_Data_train.txt",delimiter=" ",index_col=0).T

In [5]:
dffn, dffn_norm_idx = process_df_smartseq(dffn.copy(), df_meta, "MCF7_")
df2fn, df2fn_norm_idx = process_df_smartseq(df2fn.copy(), df2_meta, "HCC1806_")
df3, df3_norm_idx = process_df_dropseq(df3, "MCF7_")
df4, df4_norm_idx = process_df_dropseq(df4, "HCC1806_")

In [6]:
df3

Unnamed: 0,MALAT1,MT-RNR2,NEAT1,H1-5,TFF1,MT-RNR1,H4C3,GDF15,KRT81,MT-CO3,...,MROH1,SKIDA1,MICALL1,RARG,MYO1F,BRWD1-AS2,RPS19BP1,AUNIP,TNK2,SUDS3
MCF7_AAAAACCTATCG,1,0,0,0,4,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
MCF7_AAAACAACCCTA,3,0,0,0,1,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
MCF7_AAAACACTCTCA,3,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
MCF7_AAAACCAGGCAC,6,2,0,0,1,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
MCF7_AAAACCTAGCTC,4,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
MCF7_TTTTCGCGTAGA,0,0,0,0,3,0,7,0,0,0,...,0,0,0,0,0,0,0,0,0,0
MCF7_TTTTCGTCCGCT,1,0,0,0,4,0,1,0,1,0,...,0,0,0,0,0,0,0,0,0,0
MCF7_TTTTCTCCGGCT,0,0,0,1,2,0,4,0,0,0,...,0,0,0,0,0,0,0,0,0,1
MCF7_TTTTGTTCAAAG,0,0,0,0,6,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [7]:
df_drop_norm_idx = np.concatenate([df3_norm_idx, df4_norm_idx])
df_smart_norm_idx = np.concatenate([dffn_norm_idx, df2fn_norm_idx])
df_smart_idx = np.concatenate([dffn.index, df2fn.index])

df_all = pd.concat([dffn, df2fn, df3, df4])
df_all = df_all.fillna(0)

df_all_norm_idx = np.concatenate([dffn_norm_idx, df2fn_norm_idx, df3_norm_idx, df4_norm_idx])

df_MCF7_idx = [idx for idx in df_all.index if "MCF7" in idx]


In [8]:
#df_all["mcf"] = ["MCF7" in idx for idx in df_all.index]
#df_all["smart"] = [idx in df_smart_idx for idx in df_all.index]

In [9]:
df_all = df_all.astype(np.int32)

In [10]:
# Free up memory
del dffn, df2fn, df3, df4, dffn_norm_idx, df2fn_norm_idx, df3_norm_idx, df4_norm_idx, df_meta, df2_meta

In [11]:
import torch.nn.functional as F
import math

class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.view(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        
        output = output.view(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

In [12]:
class NNDataset(Dataset):
    def __init__(self, df, df_norm_idx):
        self.data = df.values  # Convert DataFrame to numpy array
        self.data_norm = df_norm_idx
        self.idx = df.index
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        dat = self.data[index, :]
        x = torch.tensor(dat, dtype=torch.float32, device=device)
        y = torch.tensor(0 if self.idx[index] in self.data_norm else 1,  dtype=torch.float32, device=device)
        return x, y

# Define the architecture of the autoencoder
# Define the architecture of the autoencoder
class Autoencoder2(nn.Module):
    def __init__(self, shrink_sizes, shrink_step_count):
        super(Autoencoder2, self).__init__()

        encoderLayers = []

        for i in range(shrink_step_count):
            encoderLayers.append(KANLinear(shrink_sizes[i], shrink_sizes[i + 1]))
            encoderLayers.append(nn.Dropout(0.2))
        self.encoder = nn.Sequential(*encoderLayers)

        decoderLayers = []
        for i in range(shrink_step_count):
            if (i == shrink_step_count -1):
                decoderLayers.append(KANLinear(shrink_sizes[shrink_step_count - i], shrink_sizes[shrink_step_count - i - 1] + 1))
            else:
                decoderLayers.append(KANLinear(shrink_sizes[shrink_step_count - i], shrink_sizes[shrink_step_count - i - 1]))
        self.decoder = nn.Sequential(*decoderLayers)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [13]:
input_size = df_all.shape[1]
learning_rate = 1e-2

middle_size = 3
shrink_step_count = 2

shrink_poly_factor = 4

shrink_sizes = np.linspace(input_size**(1/shrink_poly_factor), middle_size**(1/shrink_poly_factor), shrink_step_count + 1)
print(shrink_sizes)
# We want integer sizes
shrink_sizes = np.round(shrink_sizes**shrink_poly_factor).astype(int)
print(shrink_sizes)

test_amount = 0.2

[9.64253557 5.47930479 1.31607401]
[8645  901    3]


In [14]:
df_train, df_test = train_test_split(df_all, test_size=test_amount, stratify=df_all.index.isin(df_all_norm_idx))

In [15]:
#df_train_smart = df_train[df_train["smart"] == 1]
#df_train_drop = df_train[df_train["smart"] == 0]

dataset = NNDataset(df_train, df_all_norm_idx)
#dataset_drop = NNDataset(df_train_drop, df_all_norm_idx)
#dataset_smart = NNDataset(df_train_smart, df_all_norm_idx)

data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
#data_loader_drop = DataLoader(dataset_drop, batch_size=128, shuffle=True)
#data_loader_smart = DataLoader(dataset_smart, batch_size=32, shuffle=True)

In [16]:
dataset_test = NNDataset(df_test, df_all_norm_idx)
test_loader = DataLoader(dataset_test, batch_size=64, shuffle=True)

df_test_smart = df_test[[i in df_smart_idx for i in df_test.index]]
df_test_drop = df_test[[i not in df_smart_idx for i in df_test.index]]

dataset_test_smart = NNDataset(df_test_smart, df_smart_norm_idx)
dataset_test_drop = NNDataset(df_test_drop, df_drop_norm_idx)

test_loader_smart = DataLoader(dataset_test_smart, batch_size=64, shuffle=True)
test_loader_drop = DataLoader(dataset_test_drop, batch_size=64, shuffle=True)

In [17]:
# Define the model
model = Autoencoder2(shrink_sizes, shrink_step_count).to(device)
writer = SummaryWriter()

In [35]:
# Define the loss function
criterion = nn.MSELoss()
criterion1 = nn.BCEWithLogitsLoss()


In [19]:
lr_sgd = 1e-2
lr_adm = 1e-3
lr_ada = 1e-2

# Define the optimizer
#sgd = optim.SGD(model.parameters(), lr=lr_sgd)
#adam = optim.Adam(model.parameters(), lr=lr_adm)
ada = optim.Adagrad(model.parameters(), lr=lr_ada)

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
model

Autoencoder2(
  (encoder): Sequential(
    (0): KANLinear(
      (base_activation): SiLU()
    )
    (1): Dropout(p=0.2, inplace=False)
    (2): KANLinear(
      (base_activation): SiLU()
    )
    (3): Dropout(p=0.2, inplace=False)
  )
  (decoder): Sequential(
    (0): KANLinear(
      (base_activation): SiLU()
    )
    (1): KANLinear(
      (base_activation): SiLU()
    )
  )
)

In [21]:
train_config = [
#    [10, data_loader_smart, adam],
    [100, data_loader, ada]
]

In [45]:
total_epochs = 0
# Training loop
for segment in train_config:
    num_epochs = segment[0]
    data_loaderr = segment[1]
    optimizer = segment[2]
    for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
        running_loss_2 = 0.0
        for inputs, labels in data_loaderr:
            # Zero the parameter gradients
            optimizer.zero_grad()

            outputs = model(inputs)
            encoder_outputs = model(inputs)

            loss = criterion(outputs, torch.cat((inputs, torch.unsqueeze(labels, 1)), 1))

            loss_encoder = criterion1(encoder_outputs[:, 0], labels)
            
            # Backward pass and optimize
            loss.backward()
            loss_encoder.backward()

            # Clip gradients
            #nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            # Print statistics
            running_loss += loss.item()
            running_loss_2 += loss_encoder.item()

        writer.add_scalar("Loss/train", running_loss/len(data_loader), epoch)
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(data_loader)} Loss 2: {running_loss_2/len(data_loader)}", flush=True)
        writer.flush()
        total_epochs += 1

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1, Loss: 4326.149898003495 Loss 2: 0.07249341858145983


  1%|          | 1/100 [00:54<1:30:34, 54.90s/it]

Epoch 2, Loss: 4331.80506784903 Loss 2: 0.07140572446034006


  2%|▏         | 2/100 [01:49<1:29:42, 54.92s/it]

Epoch 3, Loss: 4343.4847284885855 Loss 2: 0.06972691492865915


  3%|▎         | 3/100 [02:44<1:28:50, 54.95s/it]

Epoch 4, Loss: 4345.246312697362 Loss 2: 0.06859141584971677


  4%|▍         | 4/100 [03:39<1:27:58, 54.98s/it]

Epoch 5, Loss: 4355.289168867857 Loss 2: 0.06657501229935366


  5%|▌         | 5/100 [04:34<1:27:05, 55.01s/it]

Epoch 6, Loss: 4325.836973695068 Loss 2: 0.06523025813310043


  6%|▌         | 6/100 [05:29<1:26:10, 55.01s/it]

Epoch 7, Loss: 4343.252809360559 Loss 2: 0.06333338814261166


  7%|▋         | 7/100 [06:25<1:25:18, 55.04s/it]

Epoch 8, Loss: 4333.769238106323 Loss 2: 0.06443924433349267


  8%|▊         | 8/100 [07:20<1:24:25, 55.06s/it]

Epoch 9, Loss: 4342.00848031653 Loss 2: 0.06104344400050848


  9%|▉         | 9/100 [08:15<1:23:31, 55.07s/it]

Epoch 10, Loss: 4332.416694572244 Loss 2: 0.061250629502794016


 10%|█         | 10/100 [09:10<1:22:37, 55.09s/it]

Epoch 11, Loss: 4369.562947458223 Loss 2: 0.059428862895330656


 11%|█         | 11/100 [10:05<1:21:43, 55.10s/it]

Epoch 12, Loss: 4325.583048736596 Loss 2: 0.05828938830805861


 12%|█▏        | 12/100 [11:00<1:20:52, 55.14s/it]

Epoch 13, Loss: 4335.555308571976 Loss 2: 0.05971669812646249


 13%|█▎        | 13/100 [11:55<1:19:59, 55.17s/it]

Epoch 14, Loss: 4331.501453953787 Loss 2: 0.057383777063501916


 14%|█▍        | 14/100 [12:51<1:19:06, 55.19s/it]

Epoch 15, Loss: 4326.937003991332 Loss 2: 0.05630779912614304


 15%|█▌        | 15/100 [13:46<1:18:03, 55.10s/it]

Epoch 16, Loss: 4338.934232120087 Loss 2: 0.05643470059756352


 16%|█▌        | 16/100 [14:40<1:17:01, 55.02s/it]

Epoch 17, Loss: 4325.612469332697 Loss 2: 0.055412677286759665


 17%|█▋        | 17/100 [15:36<1:16:16, 55.14s/it]

Epoch 18, Loss: 4339.170763101591 Loss 2: 0.056413972106478784


 18%|█▊        | 18/100 [16:31<1:15:31, 55.26s/it]

Epoch 19, Loss: 4333.336294890746 Loss 2: 0.05304053880112327


 19%|█▉        | 19/100 [17:26<1:14:31, 55.20s/it]

Epoch 20, Loss: 4329.819340937034 Loss 2: 0.05354480813864781


 20%|██        | 20/100 [18:22<1:13:34, 55.18s/it]

Epoch 21, Loss: 4338.495104729935 Loss 2: 0.051867159755657546


 21%|██        | 21/100 [19:16<1:12:31, 55.08s/it]

Epoch 22, Loss: 4325.309046946401 Loss 2: 0.05134146640320187


 22%|██▏       | 22/100 [20:12<1:11:38, 55.11s/it]

Epoch 23, Loss: 4333.661346984816 Loss 2: 0.05136490673150705


 23%|██▎       | 23/100 [21:07<1:10:40, 55.07s/it]

Epoch 24, Loss: 4333.259348146358 Loss 2: 0.05004951606464127


 24%|██▍       | 24/100 [22:54<1:12:32, 57.27s/it]


KeyboardInterrupt: 

In [46]:
outputs = []
cell_labels = []
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs.extend(nn.Sigmoid()(model.encoder(inputs)).cpu().numpy())
        cell_labels.extend(labels.cpu().numpy())

pred_df = pd.DataFrame(outputs)
pred_df["Condition"] = cell_labels

In [47]:
predd_df = pred_df[pred_df > -100]

In [48]:
predd_df.describe()

Unnamed: 0,0,1,2,Condition
count,7348.0,7348.0,7348.0,7348.0
mean,0.313404,0.216644,0.185044,0.488432
std,0.10827,0.175631,0.186156,0.4999
min,0.0,0.0,0.0,0.0
25%,0.229638,0.039097,0.00289,0.0
50%,0.302299,0.229299,0.184685,0.0
75%,0.337791,0.310325,0.257937,1.0
max,0.5,0.5,0.584872,1.0


In [49]:
import plotly.express as px
fig = px.scatter_3d(predd_df, x=0, y=1, z=2, color="Condition")
fig.show()

In [32]:
torch.save(model.state_dict(), "autoencc.checkpoint")

In [91]:
import gc

model.cpu()
del model, ada
gc.collect()
torch.cuda.empty_cache()