## Getting Data

In [14]:
import os
import numpy as np

folder_path = "../cohort_new/numpy arrays"
file_list = os.listdir(folder_path)

data_list = []
for file_name in file_list:
    if file_name.endswith('.npy'):
        file_path = os.path.join(folder_path, file_name)  # Construct the full file path
        array = np.load(file_path).astype(float)          # Load the array and convert its type to float
        data_list.append(array) 

In [15]:
# Squeeze the arrays to remove the redundant dimension
data_list = [arr.squeeze() for arr in data_list]

In [16]:
from scipy.ndimage import zoom
import torch

In [17]:
sample = data_list[0]
# Convert sample size to 147 x 224 x 224

# Set the desired output size
output_size = (147, 224, 224)

# Calculate the zoom factors for each dimension
zoom_factors = (output_size[0] / sample.shape[0],
                output_size[1] / sample.shape[1],
                output_size[2] / sample.shape[2])

# Resize the array using zoom
output_volume_array = zoom(sample, zoom_factors, order=1)

In [28]:
output_volume_array.shape

(147, 224, 224)

In [18]:
# Convert the array to a PyTorch tensor
output_volume_tensor = torch.from_numpy(output_volume_array).unsqueeze(0).unsqueeze(0).float()

In [19]:
output_volume_tensor.shape

torch.Size([1, 1, 147, 224, 224])


---

## Test DataAugmentation

In [24]:
from datasets.datasets_utils_test import DataAugmentationSiT

In [25]:
import argparse

args = argparse.Namespace(drop_perc=0.35, drop_type='noise', 
                          drop_align=(7, 16, 16))
data_augmentation = DataAugmentationSiT(args)

In [29]:
clean_crops, corrupted_crops, masks_crops = data_augmentation(output_volume_array)

In [30]:
clean_crops[0].shape

torch.Size([1, 147, 224, 224])

----

## Test VisionTransformer

* **Q:** How RECHead_3D works on size(2, 4116, 192)? Is the output correct? Cause it tests on size(1, 4116, 192), but what it deals with the new size? (tests in edit_functions notebook)

In [31]:
import vision_transformer_3d_test as vit
from vision_transformer_3d_test import CLSHead, RECHead_3D
import torch.nn as nn

In [32]:
class FullPipline(nn.Module):
    def __init__(self, backbone, head, head_recons):
        super(FullPipline, self).__init__()

        
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = head
        self.head_recons = head_recons

    def forward(self, x, recons=True):
        _out = self.backbone(x)
        
        if recons==True:
            print(_out[:, 1:].shape)
            return self.head(_out[:, 0]), self.head_recons(_out[:, 1:])
        else:
            return self.head(_out[:, 0]), None

In [33]:
student = vit.__dict__['vit_tiny']()
student = FullPipline(student, CLSHead(student.embed_dim, 192), RECHead_3D(192))
#student = FullPipline(student, CLSHead(256, 256), nn.Identity())

In [41]:
len(corrupted_crops)

2

In [42]:
corrupted_crops[0].shape

torch.Size([1, 147, 224, 224])

In [44]:
torch.cat(corrupted_crops[0:]).unsqueeze(1).shape

torch.Size([2, 1, 147, 224, 224])

In [45]:
torch.cat(corrupted_crops[0:]).unsqueeze(1).shape

torch.Size([2, 1, 147, 224, 224])

In [46]:
corrupted_crops[0].dtype

torch.float32

In [47]:
s_cls, s_recons  = student(torch.cat(corrupted_crops[0:]).unsqueeze(1))

torch.Size([2, 4116, 192])


In [25]:
s_cls.shape

torch.Size([2, 192])

In [26]:
s_recons.shape

torch.Size([2, 1, 147, 224, 224])

## Test Heads separately

In [19]:
student = vit.__dict__['vit_tiny']()

In [20]:
output_student = student(torch.cat(corrupted_crops[0:]).unsqueeze(1))

KeyboardInterrupt: 

In [None]:
CLSHead(student.embed_dim, 192)(output_student[:, 0]).shape

torch.Size([2, 192])

In [None]:
output_student[:, 1:].shape

torch.Size([2, 4116, 192])

In [None]:
RECHead_3D(192)(output_student[:, 1:][0].unsqueeze(0)).shape

torch.Size([1, 1, 147, 224, 224])

## Test SIMCLR

Is that possible to just remove parts with "distributed" ?! 

In [27]:
@torch.no_grad()
def concat_all_gather(tensor):
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


In [44]:
# Just a copy-paste!

class SimCLR(nn.Module):
    def __init__(self, temp=0.2):
        super().__init__()
        
        self.temp = temp
        
    def contrastive_loss(self, q, k):
        
        # normalize
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        
        # gather all targets
        #k = concat_all_gather(k) # Removed!
        logits = torch.einsum('nc,mc->nm', [q, k]) / self.temp
        N = logits.shape[0] 
        
        #Part of the labels Removed!
        labels = (torch.arange(N, dtype=torch.long) )#+ N * torch.distributed.get_rank())#.cuda()
        return nn.CrossEntropyLoss()(logits, labels) * (2 * self.temp)

    def forward(self, student_output, teacher_output, epoch):

        student_out = student_output
        student_out = student_out.chunk(2)

        teacher_out = teacher_output 
        teacher_out = teacher_out.detach().chunk(2)

        return self.contrastive_loss(student_out[0], teacher_out[1]) + self.contrastive_loss(student_out[1], teacher_out[0])



In [29]:
teacher = vit.__dict__['vit_tiny']()
teacher = FullPipline(teacher, CLSHead(teacher.embed_dim, 192), RECHead_3D(teacher.embed_dim))

In [30]:
t_cls, _  = teacher(torch.cat(clean_crops[0:]).unsqueeze(1), recons=False)

In [31]:
t_cls.shape

torch.Size([2, 192])

In [45]:
# preparing SimCLR loss
simclr_loss = SimCLR(0.2)#.cuda()

In [46]:
c_loss = simclr_loss(s_cls, t_cls, 0)

In [47]:
c_loss

tensor(0., grad_fn=<AddBackward0>)

## Test Reconstruction Loss

In [49]:
import torch.nn.functional as F

In [67]:
torch.cat(masks_crops[0:2]).shape

torch.Size([2, 147, 224, 224])

In [66]:
s_recons.shape

torch.Size([2, 1, 147, 224, 224])

In [65]:
torch.cat(clean_crops[0:]).shape

torch.Size([2, 147, 224, 224])

In [70]:
recloss = F.l1_loss(s_recons, torch.cat(clean_crops[0:]).unsqueeze(1), reduction='none')
r_loss = recloss[torch.cat(masks_crops[0:2]).unsqueeze(1)==1].mean() 

In [71]:
r_loss

tensor(587.2197, grad_fn=<MeanBackward0>)