In [17]:
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import os
import torch
import nibabel as nib
from torchvision.transforms import ToTensor
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, recall_score
from torch.nn import TripletMarginLoss, TripletMarginWithDistanceLoss

# Plotting

In [2]:
def plot_3d_tumor(mri_data, threshold=0):
    """Build an interactive 3D scatter plot of every voxel brighter than ``threshold``.

    Args:
        mri_data: 3D volume as a numpy array, or a torch tensor (which is
            squeezed and moved to the CPU first).
        threshold: voxels with intensity ``<= threshold`` are not drawn.

    Returns:
        A plotly ``Figure`` whose markers are coloured by voxel intensity.
    """
    if torch.is_tensor(mri_data):
        mri_data = mri_data.squeeze().cpu().numpy()

    # Coordinates of the voxels that pass the threshold
    voxel_x, voxel_y, voxel_z = np.where(mri_data > threshold)

    marker_style = dict(
        size=3,
        color=mri_data[voxel_x, voxel_y, voxel_z],  # intensity at each plotted voxel
        colorscale='Viridis',
        opacity=0.8,
        colorbar=dict(title='Intensity'),
    )
    voxels = go.Scatter3d(x=voxel_x, y=voxel_y, z=voxel_z, mode='markers', marker=marker_style)

    fig = go.Figure(data=[voxels])
    # aspectmode='data' keeps the axes in proportion to the data extents
    axes = dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z', aspectmode='data')
    fig.update_layout(scene=axes, title='3D Tumor Visualization')
    return fig

# Model

In [26]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import math 

class conv_block(nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU, all in float32."""

    def __init__(self, ch_in, ch_out, k_size, stride=1, padding=1):
        super().__init__()
        layers = [
            nn.Conv3d(ch_in, ch_out, kernel_size=k_size, stride=stride, padding=padding, dtype=torch.float32),
            nn.BatchNorm3d(ch_out, dtype=torch.float32),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class ResNet_block(nn.Module):
    """Two stacked ``conv_block`` layers with an identity skip connection.

    The skip is a plain addition, so the stacked convolutions must return a
    tensor with the same shape as their input.
    """

    def __init__(self, ch, k_size, stride=1):
        super().__init__()
        first = conv_block(ch, ch, k_size, stride)
        second = conv_block(ch, ch, k_size, stride)
        self.conv = nn.Sequential(first, second)

    def forward(self, x):
        residual = self.conv(x)
        return residual + x


class up_conv(nn.Module):
    """Channel projection (Conv3d) followed by trilinear upsampling."""

    def __init__(self, ch_in, ch_out, k_size=1, scale=2, align_corners=False):
        super().__init__()
        projection = nn.Conv3d(ch_in, ch_out, kernel_size=k_size)
        upsample = nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=align_corners)
        self.up = nn.Sequential(projection, upsample)

    def forward(self, x):
        return self.up(x)

class Encoder(nn.Module):
    """3D convolutional encoder: single-channel volume -> 256-channel 1x1x1 feature.

    Five stages of ``conv_block`` + ``ResNet_block`` widen the channels
    1 -> 16 -> 32 -> 64 -> 128 -> 256. The first four stages end with a
    stride-2 max-pool; the last one ends with a global adaptive average pool.
    """
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = conv_block(ch_in=1, ch_out=16, k_size=3)
        self.res_block1 = ResNet_block(ch=16, k_size=3)
        self.MaxPool1 = nn.MaxPool3d(3, stride=2, padding=1)

        self.conv2 = conv_block(ch_in=16, ch_out=32, k_size=3)
        self.res_block2 = ResNet_block(ch=32, k_size=3)
        self.MaxPool2 = nn.MaxPool3d(3, stride=2, padding=1)

        self.conv3 = conv_block(ch_in=32, ch_out=64, k_size=3)
        self.res_block3 = ResNet_block(ch=64, k_size=3)
        self.MaxPool3 = nn.MaxPool3d(3, stride=2, padding=1)

        self.conv4 = conv_block(ch_in=64, ch_out=128, k_size=3)
        self.res_block4 = ResNet_block(ch=128, k_size=3)
        self.MaxPool4 = nn.MaxPool3d(3, stride=2, padding=1)
        
        self.conv5 = conv_block(ch_in=128, ch_out=256, k_size=3)
        self.res_block5 = ResNet_block(ch=256, k_size=3)
        # Despite the name this is a global average pool, not a max pool.
        self.MaxPool5 = nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))

        self.reset_parameters()
      
    def reset_parameters(self):
        """Redraw every parameter uniformly in [-1/sqrt(n), 1/sqrt(n)], n = its first dimension.

        This loops over ``self.parameters()``, so conv biases and the BatchNorm
        weight/bias are overwritten too, not only the conv kernels.
        """
        for weight in self.parameters():
            stdv = 1.0 / math.sqrt(weight.size(0))
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x):
        """Encode ``x``; returns a tensor of shape (B, 256, 1, 1, 1).

        The shape comments below assume an input of (B, 1, 31, 202, 233), the
        fixed shape this notebook passes to ``extract_tumors`` (TODO confirm if
        the model is fed anything else).
        """
        x1 = self.conv1(x)
        x1 = self.res_block1(x1)
        x1 = self.MaxPool1(x1) # (B, 16, 16, 101, 117)
        
        x2 = self.conv2(x1)
        x2 = self.res_block2(x2)
        x2 = self.MaxPool2(x2) # (B, 32, 8, 51, 59)

        x3 = self.conv3(x2)
        x3 = self.res_block3(x3)
        x3 = self.MaxPool3(x3) # (B, 64, 4, 26, 30)
        
        x4 = self.conv4(x3)
        x4 = self.res_block4(x4) # (B, 128, 4, 26, 30)
        x4 = self.MaxPool4(x4) # (B, 128, 2, 13, 15)
        
        x5 = self.conv5(x4)
        x5 = self.res_block5(x5) # (B, 256, 2, 13, 15)
        x5 = self.MaxPool5(x5) # (B, 256, 1, 1, 1)
        return x5

class Decoder(nn.Module):
    """Decode a latent vector into a single-channel volume of size (31, 202, 233).

    A linear layer expands the latent vector to a (256, 5, 6, 5) seed volume,
    which is then upsampled in four ``up_conv`` + ``ResNet_block`` stages
    (256 -> 128 -> 64 -> 32 -> 1 channels), resized to the fixed output size
    and passed through a final single-channel residual block.
    """
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        # 150 = 5 * 6 * 5, the spatial size of the seed volume built in forward().
        self.linear_up = nn.Linear(latent_dim, 256 * 150)
        self.relu = nn.ReLU()
        
        self.upsize5 = up_conv(ch_in=256, ch_out=128, k_size=1, scale=2)
        self.res_block5 = ResNet_block(ch=128, k_size=3)
        
        self.upsize4 = up_conv(ch_in=128, ch_out=64, k_size=1, scale=2)
        self.res_block4 = ResNet_block(ch=64, k_size=3)
        
        # Anisotropic scale: first axis 20 -> ~31 (presumably chosen to reach the
        # target depth of 31 — confirm), other two axes doubled.
        self.upsize3 = up_conv(ch_in=64, ch_out=32, k_size=1, scale=(31/20, 2, 2))
        self.res_block3 = ResNet_block(ch=32, k_size=3)        
        
        # First axis is left unchanged here; only the last two axes are doubled.
        self.upsize2 = up_conv(ch_in=32, ch_out=1, k_size=1, scale=(1, 2, 2))
        self.res_block2 = ResNet_block(ch=1, k_size=3)   
        
        self.final_conv = ResNet_block(ch=1, k_size=3)
        # Forces the exact output size regardless of what the scaled stages produced.
        self.final_upsample = nn.Upsample(size=(31, 202, 233), mode='trilinear', align_corners=False)

        self.reset_parameters()
      
    def reset_parameters(self):
        """Redraw every parameter uniformly in [-1/sqrt(n), 1/sqrt(n)], n = its first dimension.

        This loops over ``self.parameters()``, so biases and BatchNorm
        weight/bias are overwritten as well.
        """
        for weight in self.parameters():
            stdv = 1.0 / math.sqrt(weight.size(0))
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x):
        """Decode ``x`` of shape (B, latent_dim); returns (B, 1, 31, 202, 233)."""
        x5_ = self.linear_up(x)
        x5_ = self.relu(x5_)
        x5_ = x5_.view(-1, 256, 5, 6, 5)  # seed volume
        
        x4_ = self.upsize5(x5_)  # (B, 128, 10, 12, 10)
        x4_ = self.res_block5(x4_)
        
        x3_ = self.upsize4(x4_)  # (B, 64, 20, 24, 20)
        x3_ = self.res_block4(x3_)
        
        x2_ = self.upsize3(x3_)  # 32 channels; last two axes become 48 and 40
        x2_ = self.res_block3(x2_)
        
        x1_ = self.upsize2(x2_)  # 1 channel; last two axes become 96 and 80
        x1_ = self.res_block2(x1_)
        
        # Resize first, then refine at full resolution.
        x0_ = self.final_upsample(x1_)
        x0_ = self.final_conv(x0_)

        return x0_

In [28]:
class VAE(nn.Module):
    """Encoder/decoder pair that currently runs as a plain autoencoder.

    The variational part (``z_mean``, ``z_log_sigma``, ``epsilon``) is created
    in ``__init__`` but the code that uses it in ``forward`` is commented out,
    so the decoder is fed the flattened encoder output directly.

    NOTE(review): the encoder output has 256 features while the decoder's input
    width is ``latent_dim``; with the sampling step disabled these only match
    for the default ``latent_dim=256``.
    """
    def __init__(self, latent_dim=256):
        super(VAE, self).__init__()
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.latent_dim = latent_dim
        # Registered (and therefore saved/optimised) but not used by forward().
        self.z_mean = nn.Linear(256, latent_dim)
        self.z_log_sigma = nn.Linear(256, latent_dim)
        # Noise drawn once at construction and stored as a plain tensor attribute;
        # only the commented-out reparameterisation in forward() refers to it.
        self.epsilon = torch.normal(size=(1, latent_dim), mean=0, std=1.0, device=self.device)
        self.encoder = Encoder()
        self.decoder = Decoder(latent_dim)

        # Runs over all parameters, including the encoder's and decoder's, which
        # their own constructors have already initialised the same way.
        self.reset_parameters()
      
    def reset_parameters(self):
        """Redraw every parameter uniformly in [-1/sqrt(n), 1/sqrt(n)], n = its first dimension."""
        for weight in self.parameters():
            stdv = 1.0 / math.sqrt(weight.size(0))
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x):
        """Return ``(embedding, reconstruction)`` for a batch of volumes.

        ``embedding`` is the encoder output flattened to (B, 256);
        ``reconstruction`` is the decoder output after a ReLU, so it is
        non-negative.
        """
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        # Disabled reparameterisation step (kept for reference):
#         z_mean = self.z_mean(x)
#         z_log_sigma = self.z_log_sigma(x)
#         z = z_mean + z_log_sigma.exp()*self.epsilon
        y = self.decoder(x)
        y = torch.nn.functional.relu(y, inplace=True)
        return x, y # , z_mean, z_log_sigma

# Losses

In [29]:
class KLDivergence(nn.Module):
    "KL divergence between the estimated normal distribution and a prior distribution"

    def __init__(self):
        super().__init__()
        # N spans all dimensions of the input (N = H x W x D).
        # It is stored on the module but forward() does not use it.
        self.N = 80 * 96 * 80

    def forward(self, z_mean, z_log_sigma):
        """Return 0.5 * sum(mu^2 + exp(log_var) - log_var - 1), with log_var = 2 * log_sigma."""
        log_var = 2 * z_log_sigma
        per_element = z_mean.pow(2) + torch.exp(log_var) - log_var - 1
        return per_element.sum() * 0.5

class L2Loss(nn.Module):
    """Mean squared error between a prediction and the ground truth.

    The summed squared difference is divided by the number of elements in
    ``y``, so the loss works for tensors of any rank (the original shape
    product only supported 5-D input).
    """
    def __init__(self):
        super(L2Loss, self).__init__()

    def forward(self, x, y):
        """Return ``sum((x - y) ** 2) / y.numel()`` as a 0-dim tensor."""
        return ((x - y) ** 2).sum() / y.numel()

class L1Loss(nn.Module):
    """Mean absolute error (L1 norm) between a prediction and the ground truth.

    The summed absolute difference is divided by the number of elements in
    ``y``, so the loss works for tensors of any rank (the original shape
    product only supported 5-D input).
    """
    def __init__(self):
        super(L1Loss, self).__init__()

    def forward(self, x, y):
        """Return ``sum(|x - y|) / y.numel()`` as a 0-dim tensor."""
        return (x - y).abs().sum() / y.numel()

# Dataset

In [30]:
class TumorDataset(Dataset):
    """MRI volumes with tumor masks and a mutation label.

    Each item is ``(image, mask, label)`` where ``label`` is 1 (mutation),
    0 (no mutation) or 2 (mutation status unknown). ``image_files``,
    ``mask_files``, ``labels`` and ``is_labeled`` are parallel lists with one
    entry per sample.

    Args:
        transform: callable applied to both the image and the mask array.
        include_mutation: add the scans labeled as mutation.
        include_no_mutation: add the scans labeled as no mutation.
        include_unknown: add scans with unknown mutation status.
        unknown: which unknown subset to add: ``'both'``, ``'300'`` (the
            ``less_300`` folder) or any other value (the ``over_700`` folder).
    """

    def __init__(self, transform=ToTensor(), include_mutation=True, include_no_mutation=True, include_unknown=True, unknown='both'):
        self.transform = transform
        self.image_files = []
        self.mask_files = []
        self.labels = []  # 1 for mutation, 0 for no mutation, 2 for unknown
        self.is_labeled = []  # False for unknown-status samples

        base_dir = '/kaggle/input/studcampfullfiles'
        ext_dir = f'{base_dir}/data_ext/data_ext'

        if include_mutation:
            self._add_data(f'{base_dir}/series_nii_mutation_anonym/series_nii_mutation_anonym',
                           f'{base_dir}/masks_nii_mutation_anonym/masks_nii_mutation_anonym', 1)

        if include_no_mutation:
            self._add_data(f'{base_dir}/series_nii_no_mutation_anonym/series_nii_no_mutation_anonym',
                           f'{base_dir}/masks_nii_no_mutation_anonym/masks_nii_no_mutation_anonym', 0)

        if include_unknown:
            if unknown == 'both':
                subsets = ['less_300', 'over_700']
            elif unknown == '300':
                subsets = ['less_300']
            else:
                subsets = ['over_700']
            for subset in subsets:
                self._add_unknown_data(f'{ext_dir}/series_nii_unknown_{subset}_anonym',
                                       f'{ext_dir}/masks_nii_unknown_{subset}_anonym')

    def _add_data(self, image_dir, mask_dir, label):
        """Register every ``.nii`` image in ``image_dir`` with the given label.

        The mask for ``<name>.nii`` is expected at ``<mask_dir>/<name>_label.nii``;
        its existence is not checked here.
        """
        image_files = [f for f in os.listdir(image_dir) if f.endswith('.nii')]
        mask_files = [image_file.split('.')[0] + '_label.nii' for image_file in image_files]

        self.image_files.extend([os.path.join(image_dir, f) for f in image_files])
        self.mask_files.extend([os.path.join(mask_dir, f) for f in mask_files])
        self.labels.extend([label] * len(image_files))
        self.is_labeled.extend([True] * len(image_files))

    def _add_unknown_data(self, image_dir, mask_dir):
        """Register unknown-status images that have a mask of the same base name.

        Files with another extension, or without a matching mask, are skipped.
        """
        for file in os.listdir(image_dir):
            base_name, ext = os.path.splitext(file)
            if ext.lower() in ['.nii', '.jpg']:  # Handle both .nii and misnamed .jpg files
                image_path = os.path.join(image_dir, file)
                mask_path = os.path.join(mask_dir, base_name + '.nii')  # Mask has the same name as image file

                if os.path.exists(mask_path):
                    self.image_files.append(image_path)
                    self.mask_files.append(mask_path)
                    self.labels.append(2)
                    # One flag per sample actually added. Extending by the raw
                    # directory size would leave is_labeled longer than
                    # image_files whenever a file is skipped.
                    self.is_labeled.append(False)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image and mask ``idx`` from disk and return ``(image, mask, label)``."""
        image_path = self.image_files[idx]
        mask_path = self.mask_files[idx]
        label = self.labels[idx]

        image = np.array(nib.load(image_path).get_fdata())
        mask = np.array(nib.load(mask_path).get_fdata())

        # rotate, augment, anything needed
        image = self.transform(image)
        mask = self.transform(mask)

        return image, mask, label

# Extracting tumor

In [31]:
def extract_tumor_region(mri, mask):
    """Crop ``mri`` and ``mask`` to the tumor's bounding box plus a 5-voxel margin.

    The box spans all non-zero voxels of ``mask`` and is clipped to the volume
    bounds. Both inputs are indexed along their first three axes.

    Returns:
        ``(tumor_mri, tumor_mask)``, the two crops taken with the same box.
    """
    padding = 5
    coords = torch.nonzero(mask)

    # Inclusive corners of the padded box, clipped to the volume
    lower = torch.clamp(coords.min(dim=0).values - padding, min=0)
    upper = torch.clamp(coords.max(dim=0).values + padding, max=torch.tensor(mask.shape) - 1)

    box = (
        slice(lower[0], upper[0] + 1),
        slice(lower[1], upper[1] + 1),
        slice(lower[2], upper[2] + 1),
    )
    return mri[box], mask[box]

In [32]:
def central_pad(tensor, target_shape):
    """Zero-pad ``tensor`` on both sides of every axis up to ``target_shape``.

    When the missing amount on an axis is odd, the extra element goes in
    front of the data. A negative amount (axis already larger than the target)
    is passed to ``pad`` unchanged, which trims instead of padding.
    """
    padding = []
    # pad() expects (before, after) pairs ordered from the last axis to the first
    for axis in reversed(range(len(tensor.shape))):
        gap = target_shape[axis] - tensor.shape[axis]
        after = gap // 2
        before = gap - after
        padding += [before, after]
    return torch.nn.functional.pad(tensor, padding)

def central_crop(tensor, target_shape):
    """Centre-crop every axis of ``tensor`` that is longer than ``target_shape``.

    Axes that are already at or below the target length are left untouched.
    """
    def _window(size, wanted):
        if size <= wanted:
            return slice(None)
        offset = (size - wanted) // 2
        return slice(offset, offset + wanted)

    windows = tuple(_window(size, target_shape[axis]) for axis, size in enumerate(tensor.shape))
    return tensor[windows]

In [9]:
# Module-level list that extract_tumors() appends one median intensity to per
# sample, across every call. Re-running this cell resets it.
scales = []

In [10]:
def extract_tumors(dataset, fixed_shape, return_label=False):
    """Cut the tumor out of every sample and bring it to ``fixed_shape``.

    Each tumor crop is centre-cropped to at most ``fixed_shape``, zero-padded
    up to it, given a leading channel axis, cast to float32 and divided by the
    median of the (unpadded) crop. That median is also appended to the
    module-level ``scales`` list.

    NOTE(review): a crop whose median is 0 produces inf/nan values here.

    Args:
        dataset: iterable of ``(mri, mask, label)`` triples.
        fixed_shape: target spatial shape.
        return_label: when True each entry is ``[tumor, label]`` instead of
            just ``tumor``.

    Returns:
        A list with one entry per sample, in dataset order.
    """
    prepared = []
    for mri, mask, label in dataset:
        tumor_mri, _ = extract_tumor_region(mri, mask)
        cropped = central_crop(tumor_mri, fixed_shape)

        median = torch.median(cropped)
        scales.append(median)

        normalized = central_pad(cropped, fixed_shape).unsqueeze(0).to(torch.float32) / median
        prepared.append([normalized, label] if return_label else normalized)
    return prepared

In [11]:
def create_split(dataset, train_ratio=0.4, seed=42):
    """Split ``dataset`` into a training set and a labeled-only validation set.

    Labeled samples (``is_labeled`` and label 0 or 1) are split at random:
    ``train_ratio`` of them go to training, the rest to validation. Every
    unlabeled sample is added to the training set.

    Args:
        dataset: object exposing parallel ``is_labeled`` and ``labels`` lists.
        train_ratio: fraction of the labeled samples used for training.
        seed: passed to ``torch.manual_seed`` (this reseeds the global
            generator) so that the split is reproducible.

    Returns:
        ``(train_dataset, val_dataset)``. ``train_dataset`` is a ``Subset`` of
        ``dataset`` listing the labeled training samples first and the
        unlabeled ones after them.
    """
    torch.manual_seed(seed)

    labeled_indices = [
        i for i, (is_labeled, label) in enumerate(zip(dataset.is_labeled, dataset.labels))
        if is_labeled and label in [0, 1]
    ]
    unlabeled_indices = [i for i, is_labeled in enumerate(dataset.is_labeled) if not is_labeled]

    labeled_dataset = Subset(dataset, labeled_indices)
    labeled_train_size = int(train_ratio * len(labeled_dataset))
    val_size = len(labeled_dataset) - labeled_train_size

    labeled_train_dataset, val_dataset = random_split(labeled_dataset, [labeled_train_size, val_size])

    # random_split returns positions inside labeled_dataset. Translate them back
    # to positions in dataset before mixing them with unlabeled_indices,
    # otherwise the wrong samples are picked whenever the labeled samples are
    # not stored at positions 0..k-1.
    labeled_train_indices = [labeled_indices[i] for i in labeled_train_dataset.indices]
    train_dataset = Subset(dataset, labeled_train_indices + unlabeled_indices)

    return train_dataset, val_dataset

# Split into train and validation sets (the PCA plot is made on the validation set, which the model never sees during training)

In [12]:
# Load every subset (mutation, no mutation, and both unknown folders), then split:
# with create_split's default train_ratio=0.4, 40% of the labeled samples plus all
# unlabeled samples go to training; the remaining labeled samples form validation.
dataset = TumorDataset(include_mutation=True, include_no_mutation=True, include_unknown=True)
train_dataset, val_dataset = create_split(dataset)

# Finally extracting

In [13]:
# Every tumor is brought to the fixed (31, 202, 233) shape, which is also the
# Decoder's output size. Each call appends to the global `scales` list, so the
# training entries come first and the validation entries follow.
only_tumors = extract_tumors(train_dataset, (31, 202, 233), return_label=True)
only_tumors_validate = extract_tumors(val_dataset, (31, 202, 233), return_label=True)

# Training

In [35]:
# Training hyper-parameters and seeds
batch_size = 6
lrate = 0.0001 # 0.001 was best
epochs = 200
np.random.seed(10)
torch.manual_seed(10)

# Losses: reconstruction (MSE), KL divergence, and a triplet loss on the embeddings
criterion_rec = L2Loss()
criterion_dis = KLDivergence()
# criterion_triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - F.cosine_similarity(x, y), margin=0.2)
criterion_triplet = nn.TripletMarginWithDistanceLoss(margin=1)
dataloader_train = DataLoader(only_tumors, batch_size=batch_size, shuffle=True)
dataloader_validate = DataLoader(only_tumors_validate, batch_size=batch_size, shuffle=True)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Count the samples themselves: len(dataloader_train) * batch_size overcounts
# whenever the last batch is only partially filled, which skews the per-sample
# losses that are divided by this number.
no_images = len(dataloader_train.dataset)
print("Number of MRI images: ", no_images)

Number of MRI images:  102


In [36]:
# vae_model_params = torch.load('/kaggle/input/vae/pytorch/ver2/1/VAEtumors_epoch_90.pt', map_location=torch.device('cpu'))
# vae_model = VAE()
# vae_model.load_state_dict(vae_model_params)

In [37]:
# Fresh model, Adam over all of its parameters, and a step schedule that halves
# the learning rate after every 50 scheduler steps (train() steps it once per epoch).
vae_model = VAE()
optimizer = torch.optim.Adam(vae_model.parameters(), lr=lrate)
scheduler = StepLR(optimizer, step_size=50, gamma=1/2)

In [38]:
 # Total number of trainable parameters in the model
 sum(p.numel() for p in vae_model.parameters() if p.requires_grad)

17231321

In [39]:
# Move the model to the training device and make sure its parameters are float32.
vae_model.to(device)
vae_model.to(torch.float32)
# Prints an empty line; presumably here so the cell does not echo the module repr.
print()




In [40]:
def _batch_triplet_loss(embeddings, batch_labels, criterion_triplet):
    """Triplet loss for one (anchor, positive, negative) triple taken from a batch.

    Returns a zero tensor when the batch has fewer than three labeled samples,
    does not contain both classes, or has fewer than two samples of the
    randomly chosen anchor class.
    """
    no_loss = embeddings.new_zeros(())

    labeled = (batch_labels == 0) | (batch_labels == 1)
    if labeled.sum() < 3:
        return no_loss

    labeled_labels = batch_labels[labeled]
    if len(torch.unique(labeled_labels)) != 2:
        return no_loss

    labeled_embeddings = embeddings.view(embeddings.size(0), -1)[labeled]

    # Same np.random call as before, so seeded runs pick the same anchor class.
    anchor_label = int(np.random.choice([0, 1]))
    same_class = labeled_embeddings[labeled_labels == anchor_label]
    if same_class.size(0) < 2:
        return no_loss
    # Both classes are present (checked above), so this is never empty.
    other_class = labeled_embeddings[labeled_labels == 1 - anchor_label]

    # Anchor and positive are the first two samples of the anchor class, the
    # negative is the first sample of the other class; slices keep the batch axis.
    return criterion_triplet(same_class[0:1], same_class[1:2], other_class[0:1])


def train(vae_model, optimizer, scheduler, epochs, dataloader_train, dataloader_val, criterion_rec, criterion_dis, criterion_triplet):
    """Train the model with reconstruction + 2 * triplet loss.

    Every 10th epoch (starting with epoch 0) the model is evaluated on
    ``dataloader_val``: the reconstruction loss is reported, the validation
    embeddings are projected to 2 PCA components and plotted, KNN and random
    forest classifiers are cross-validated on that projection, and the model
    weights are saved to ``VAEtumors_epoch_{epoch}.pt``.

    Args:
        vae_model: model returning ``(embeddings, reconstruction)``; it must
            already be on the device it should train on.
        optimizer: optimizer over the model parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        epochs: number of epochs to run.
        dataloader_train: yields ``(images, labels)``; labels other than 0/1
            are treated as unlabeled for the triplet loss.
        dataloader_val: yields ``(images, labels)`` for the periodic evaluation.
        criterion_rec: reconstruction loss.
        criterion_dis: accepted for interface compatibility; currently unused.
        criterion_triplet: triplet loss taking (anchor, positive, negative).
    """
    # Take the device from the model and the sample count from the loader,
    # rather than from notebook-level globals.
    model_device = next(vae_model.parameters()).device
    num_train_samples = len(dataloader_train.dataset)

    for epoch in tqdm(range(epochs)):
        loss_rec_epoch, loss_triplet_epoch = 0, 0

        # Training phase
        vae_model.train()
        for batch_images, batch_labels in dataloader_train:
            optimizer.zero_grad()
            batch_images = batch_images.to(device=model_device, dtype=torch.float32)
            batch_labels = batch_labels.to(model_device)
            embeddings, reconstruction = vae_model(batch_images)

            loss_rec_batch = criterion_rec(batch_images, reconstruction)
            loss_triplet_batch = _batch_triplet_loss(embeddings, batch_labels, criterion_triplet)
            total_loss_batch = loss_rec_batch + 2 * loss_triplet_batch

            total_loss_batch.backward()
            optimizer.step()

            # Weight by the real batch size so a partial last batch counts correctly
            loss_rec_epoch += loss_rec_batch.item() * batch_images.shape[0]
            loss_triplet_epoch += loss_triplet_batch.item() * batch_images.shape[0]
        scheduler.step()

        if epoch % 10 == 0:
            # Validation phase
            vae_model.eval()
            val_loss_rec = 0
            val_embeddings, val_labels = [], []

            with torch.no_grad():
                for batch_images, batch_labels in dataloader_val:
                    batch_images = batch_images.to(device=model_device, dtype=torch.float32)
                    # The forward pass already returns the encoder output, so
                    # there is no need to run the encoder a second time.
                    batch_embeddings, reconstruction = vae_model(batch_images)

                    val_loss_rec += criterion_rec(batch_images, reconstruction).item() * batch_images.shape[0]
                    val_embeddings.append(batch_embeddings.view(batch_images.size(0), -1).cpu().numpy())
                    val_labels.extend(batch_labels.tolist())

            val_loss_rec /= len(dataloader_val.dataset)

            features = pd.DataFrame(np.vstack(val_embeddings))
            targets = pd.Series(val_labels)

            n_components = 2
            pca = PCA(n_components=n_components)
            principal_components = pca.fit_transform(features)
            plt.scatter(principal_components[:, 0], principal_components[:, 1], c=np.array(targets))
            plt.show()

            features_2d = pd.DataFrame(data=principal_components, columns=[f'Principal Component {i+1}' for i in range(n_components)])
            classifiers = (
                ('KNN', KNeighborsClassifier(n_neighbors=3)),
                ('RFC', RandomForestClassifier(random_state=42)),
            )
            print('Val loss rec', val_loss_rec)
            for name, clf in classifiers:
                cv_scores_rec = cross_val_score(clf, features_2d, targets, cv=5, scoring='recall')
                cv_scores_acc = cross_val_score(clf, features_2d, targets, cv=5, scoring='accuracy')
                print(name, 'Recall:', np.mean(cv_scores_rec), 'Accuracy:', np.mean(cv_scores_acc))

        print('Epoch', epoch, 'Loss rec', loss_rec_epoch/num_train_samples, 'Loss triplet', loss_triplet_epoch/num_train_samples)
        if epoch % 10 == 0:
            torch.save(vae_model.state_dict(), f'VAEtumors_epoch_{epoch}.pt')
            

In [None]:
# Run the full training loop. criterion_dis is passed along, but train() does not use it.
train(vae_model, optimizer, scheduler, epochs, dataloader_train, dataloader_validate, criterion_rec, criterion_dis, criterion_triplet)

  0%|          | 0/200 [00:00<?, ?it/s]

# Visualise reconstruction

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vae_model.to(device)
vae_model.eval()

with torch.no_grad():
    # First training sample: entry [0] is the normalised tumor, unsqueeze adds the batch axis
    mri = only_tumors[0][0].unsqueeze(0).to(device)        
    _, output = vae_model(mri)
    # Undo the median normalisation. scales[0] belongs to only_tumors[0] only
    # if the training set was the first extract_tumors() call since `scales`
    # was reset — assumes the cells were run in order.
    output = output * scales[0]
    output = output.squeeze(0).squeeze(0).detach().cpu().numpy()
    fig = plot_3d_tumor(output)
    fig.show()


# Visualise actual tumor

In [24]:
# `device` is assigned here but not used by the rest of this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Same sample as in the reconstruction cell, rescaled back to its original intensities.
fig = plot_3d_tumor(only_tumors[0][0].unsqueeze(0) * scales[0])
fig.show()

# Make latent mutation csv

In [25]:
# Encode every mutation scan and store one 256-value latent per file
dataset_mutation = TumorDataset(include_mutation=True, include_no_mutation=False, include_unknown=False)
only_tumors_mutation = extract_tumors(dataset_mutation, (31, 202, 233))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vae_model.to(device)
vae_model.eval()
results = {}
with torch.no_grad():
    for tumor, image_path in zip(only_tumors_mutation, dataset_mutation.image_files):
        # File name without directory and without the .nii suffix
        case_id = image_path.split('/')[-1].split('.nii')[0]
        latent = vae_model.encoder(tumor.unsqueeze(0).to(device))
        results[case_id] = latent.view(256).detach().cpu().numpy()

df = pd.DataFrame.from_dict(results, orient='index')
df.to_csv('mutation_latents.csv', index=True, index_label='filename')

KeyboardInterrupt: 

# Make latent no mutation csv

In [None]:
# Encode every no-mutation scan and store one 256-value latent per file
dataset_no_mutation = TumorDataset(include_mutation=False, include_no_mutation=True, include_unknown=False)
only_tumors_no_mutation = extract_tumors(dataset_no_mutation, (31, 202, 233))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vae_model.to(device)
vae_model.eval()
results = {}
with torch.no_grad():
    for tumor, image_path in zip(only_tumors_no_mutation, dataset_no_mutation.image_files):
        # File name without directory and without the .nii suffix
        case_id = image_path.split('/')[-1].split('.nii')[0]
        latent = vae_model.encoder(tumor.unsqueeze(0).to(device))
        results[case_id] = latent.view(256).detach().cpu().numpy()

df = pd.DataFrame.from_dict(results, orient='index')
df.to_csv('no_mutation_latents.csv', index=True, index_label='filename')

# Make latent unknown 700/300 csv

In [None]:
# Encode the unknown-status scans of the "over 700" subset, one 256-value latent per file
dataset_unknown = TumorDataset(include_mutation=False, include_no_mutation=False, include_unknown=True, unknown='700')
only_tumors_unknown = extract_tumors(dataset_unknown, (31, 202, 233))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vae_model.to(device)
vae_model.eval()
results = {}
with torch.no_grad():
    for tumor, image_path in zip(only_tumors_unknown, dataset_unknown.image_files):
        # File name without directory; only a .nii suffix is stripped
        case_id = image_path.split('/')[-1].split('.nii')[0]
        latent = vae_model.encoder(tumor.unsqueeze(0).to(device))
        results[case_id] = latent.view(256).detach().cpu().numpy()

df = pd.DataFrame.from_dict(results, orient='index')
df.to_csv('unknown_700_latents.csv', index=True, index_label='filename')