In [None]:
import os
import time

import torch
from torch import nn
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
from torchvision import transforms
import torchvision
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from efficientunet import *
import numpy as np
import matplotlib.pyplot as plt
from torchvision.io import read_image
import kornia
from kornia.augmentation import *
from kornia.utils import get_cuda_or_mps_device_if_available, tensor_to_image

In [None]:
def find_imgs(directory_path, image_extensions=(".jpg", ".jpeg", ".png", ".gif", ".bmp")):
    """Recursively collect image files under ``directory_path``.

    Args:
        directory_path: root folder to search (walked with ``os.walk``).
        image_extensions: file-name suffixes that count as images; matched
            case-insensitively against the lower-cased file name.

    Returns:
        Sorted list of full image paths (empty if nothing matches or the
        directory does not exist). ``os.walk`` yields entries in arbitrary,
        filesystem-dependent order; sorting makes the list -- and any
        index-based dataset split built on top of it -- reproducible.
    """
    suffixes = tuple(ext.lower() for ext in image_extensions)
    return sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(directory_path)
        for name in files
        if name.lower().endswith(suffixes)
    )

In [None]:
# Root folder of the lung images; searched recursively by find_imgs.
lungs_path = '/gdrive/MyDrive/JustLungs'

In [None]:
class LungDataset(torch.utils.data.Dataset):
  """Triplet dataset for self-supervised pre-training on grayscale lung images.

  Each item is ``(anchor, positive, negative)``; every element is one grayscale
  image stacked three times along a new leading (channel) dimension:

  * anchor   -- image ``idx`` with the geometric augmentations (``apply_aug``);
  * positive -- the same image ``idx`` with the SSL augmentations
    (``apply_ssl_aug``: the geometric ones followed by a jigsaw shuffle);
  * negative -- a different, uniformly chosen image with the SSL augmentations.

  Args:
    img_paths: sequence of image file paths.
    transforms: optional callable applied to each loaded image before the
      augmentations (e.g. a resize); ``None`` to skip.
    device: device the images are loaded onto. Defaults to
      ``get_cuda_or_mps_device_if_available()`` from ``kornia.utils``.
  """

  def __init__(self, img_paths, transforms, device=None):
    self.img_paths = img_paths
    self.transforms = transforms
    # Explicit device instead of relying on a notebook-level global.
    self.device = device if device is not None else get_cuda_or_mps_device_if_available()
    self.img_type = kornia.io.ImageLoadType.GRAY32
    self.randomperspective = RandomPerspective(0.3, "nearest", align_corners=True, same_on_batch=False, keepdim=True, p=0.5)
    self.randomHorizontalflip = RandomHorizontalFlip(same_on_batch=False, keepdim=True, p=0.6, p_batch=0.5)
    self.randomElastic = RandomElasticTransform(alpha=(0.3, 0.3), p=0.5, keepdim=True)
    self.randomRotation = RandomRotation(degrees=20.0, p=0.5, keepdim=True)
    self.randomJigsaw = RandomJigsaw((4, 4), p=0.3, keepdim=True)

  def _load(self, idx):
    """Load image ``idx`` with ``self.img_type`` onto ``self.device``."""
    return kornia.io.load_image(self.img_paths[idx], self.img_type, self.device)

  def __getitem__(self, idx):
    img_org = self._load(idx)
    x_neg = self._load(self.get_random_negative_index(idx))
    if self.transforms is not None:
      img_org = self.transforms(img_org)
      x_neg = self.transforms(x_neg)

    # Anchor: geometric augmentations only.
    # NOTE(review): squeeze(0) assumes the loaded image is (1, H, W) -- confirm.
    x = self.apply_aug(img_org).squeeze(0)
    # Negative and positive both get the SSL augmentations (incl. jigsaw).
    x_neg = self.apply_ssl_aug(x_neg).squeeze(0)
    x_pos = self.apply_ssl_aug(img_org).squeeze(0)

    # Replicate the single channel three times for the 3-channel backbone.
    return torch.stack([x] * 3), torch.stack([x_pos] * 3), torch.stack([x_neg] * 3)

  def get_random_negative_index(self, current_idx):
    """Return a uniformly random index in ``[0, len(self))`` other than ``current_idx``.

    O(1): draws from ``n - 1`` slots and shifts past ``current_idx`` instead of
    permuting the whole dataset on every item.

    Raises:
      ValueError: if the dataset has fewer than two images.
    """
    n = len(self.img_paths)
    if n < 2:
      raise ValueError("LungDataset needs at least two images to draw a negative sample")
    idx_neg = int(torch.randint(0, n - 1, (1,)).item())
    if idx_neg >= current_idx:
      idx_neg += 1
    return idx_neg

  def apply_aug(self, x):
    """Geometric augmentations: perspective, horizontal flip, elastic, rotation."""
    x = self.randomperspective(x)
    x = self.randomHorizontalflip(x)
    x = self.randomElastic(x)
    x = self.randomRotation(x)
    return x

  def apply_ssl_aug(self, x):
    """SSL augmentations: ``apply_aug`` followed by a random jigsaw shuffle."""
    return self.randomJigsaw(self.apply_aug(x))

  def __len__(self):
    return len(self.img_paths)


In [None]:
all_files = find_imgs(lungs_path)
# Images flagged as unusable ("bad apples"), stored in the "logs (1)" folder on Drive.
# NOTE(review): presumably an array of path strings matching find_imgs output -- confirm.
bad_apples = np.load('/gdrive/MyDrive/logs (1)/bad_apple.npy')
# `path in ndarray` scans the whole array for every file (O(N*M)); a set makes it O(1).
# ravel() keeps the original "match any element" semantics for any array shape.
bad_apple_set = set(bad_apples.ravel().tolist())
train_files = [path for path in all_files if path not in bad_apple_set]

In [None]:
# Resize to exactly 224x224: Resize(224) would only fix the shorter edge, while
# PreModel's projector input size (5880) is hard-coded for one fixed spatial size.
lung_ds = LungDataset(train_files, transforms.Compose([transforms.Resize((224, 224))]))

# 90 % train / 10 % test
train_size = int(0.9 * len(lung_ds))
test_size = len(lung_ds) - train_size

# Seeded generator: the split is identical after every kernel restart, so a run
# resumed from a checkpoint never evaluates on images it was trained on.
SPLIT_SEED = 42
train_dataset, test_dataset = random_split(
    lung_ds, [train_size, test_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# drop_last on the training loader: a trailing batch of size 1 would make the
# projector's BatchNorm1d raise in training mode.
train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=5, shuffle=False)

In [None]:
# The pretrained EfficientNet-B3 U-Net is constructed together with the model
# (see the `model = PreModel(b3unet.encoder)` cell); no separate build is needed here.

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b3-5fb5a3c3.pth
100%|██████████| 47.1M/47.1M [00:00<00:00, 296MB/s]


In [None]:
class LinearLayer(torch.nn.Module):
    """A ``Linear`` layer, optionally followed by ``BatchNorm1d``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        use_bias: add a bias to the linear layer. Ignored (no bias) when
            ``use_bn`` is True, because the batch norm's own shift makes it
            redundant.
        use_bn: apply ``BatchNorm1d(out_features)`` after the linear layer.
        **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 use_bias=True,
                 use_bn=False,
                 **kwargs):
        super(LinearLayer, self).__init__(**kwargs)

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.use_bn = use_bn

        # torch.nn is referenced explicitly: `nn` is not imported by name in
        # this notebook's import cell.
        self.linear = torch.nn.Linear(self.in_features,
                                      self.out_features,
                                      bias=self.use_bias and not self.use_bn)
        if self.use_bn:
            self.bn = torch.nn.BatchNorm1d(self.out_features)

    def forward(self, x):
        x = self.linear(x)
        if self.use_bn:
            x = self.bn(x)
        return x

class ProjectionHead(torch.nn.Module):
    """Projection head mapping backbone features to the embedding space.

    Args:
        in_features: input feature size.
        hidden_features: hidden size (used only by the ``'nonlinear'`` head).
        out_features: embedding size.
        head_type: ``'linear'`` -> Linear + BatchNorm1d (no bias);
            ``'nonlinear'`` -> Linear+BN -> ReLU -> Linear+BN.
        **kwargs: forwarded to ``torch.nn.Module.__init__``.

    Raises:
        ValueError: if ``head_type`` is neither ``'linear'`` nor ``'nonlinear'``
            (previously ``self.layers`` was silently left undefined and the
            error only surfaced in ``forward``).
    """

    def __init__(self,
                 in_features,
                 hidden_features,
                 out_features,
                 head_type='nonlinear',
                 **kwargs):
        super(ProjectionHead, self).__init__(**kwargs)
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.head_type = head_type

        if self.head_type == 'linear':
            self.layers = LinearLayer(self.in_features, self.out_features,
                                      use_bias=False, use_bn=True)
        elif self.head_type == 'nonlinear':
            self.layers = torch.nn.Sequential(
                LinearLayer(self.in_features, self.hidden_features,
                            use_bias=True, use_bn=True),
                torch.nn.ReLU(),
                LinearLayer(self.hidden_features, self.out_features,
                            use_bias=False, use_bn=True))
        else:
            raise ValueError(
                f"head_type must be 'linear' or 'nonlinear', got {head_type!r}")

    def forward(self, x):
        return self.layers(x)

In [None]:
class PreModel(torch.nn.Module):
    """Backbone + channel-reducing conv + projection head for triplet pre-training.

    Args:
        base_model: backbone module whose forward output feeds a
            ``Conv2d(1536 -> 120)``, so it must produce 1536-channel feature maps.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        # Reduce the 1536 backbone channels to 120 before flattening.
        self.conv_layer = torch.nn.Conv2d(in_channels=1536, out_channels=120,
                                          kernel_size=(2, 2), stride=(1, 1), padding='same')
        # Fine-tune the whole backbone: make sure no parameter is frozen.
        for p in self.base_model.parameters():
            p.requires_grad = True
        # 5880 == 120 * 49, i.e. assumes a 7x7 feature map out of the backbone
        # (NOTE(review): presumably 224x224 input to the B3 encoder -- confirm
        # before changing the input resolution).
        self.projector = ProjectionHead(5880, 2048, 128)

    def forward(self, x):
        out = self.base_model(x)
        out = self.conv_layer(out)
        # flatten(1) always keeps the batch dimension; the previous torch.squeeze
        # dropped it for a batch of size 1, which breaks BatchNorm1d in the head.
        out_flat = torch.flatten(out, start_dim=1)
        return self.projector(out_flat)

In [None]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [None]:
# Pretrained EfficientNet-B3 U-Net; only its encoder is used as the SSL backbone.
b3unet = get_efficientunet_b3(pretrained=True, out_channels=1, concat_input=True)
model = PreModel(base_model=b3unet.encoder).to(device)

In [None]:
# Triplet margin loss with Euclidean (p=2) distance: penalises triplets where the
# anchor-positive distance is not at least `margin` smaller than anchor-negative.
triplet_loss = torch.nn.TripletMarginLoss(p=2, margin=1.0, eps=1e-7)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
# Multiplies the learning rate by `gamma` on every scheduler.step() call
# (has no effect unless the training loop calls scheduler.step()).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)

In [None]:
def save_model(model, optimizer, scheduler, current_epoch, name, log_dir='/gdrive/MyDrive/logs/'):
    """Write a training checkpoint (model, optimizer and scheduler state dicts).

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        current_epoch: value substituted into ``name`` via ``name.format(...)``.
        name: file-name template, e.g. ``"ckpt_{}.pt"``.
        log_dir: directory for checkpoints; created (with parents) if missing.

    Returns:
        The path the checkpoint was written to.
    """
    os.makedirs(log_dir, exist_ok=True)
    out = os.path.join(log_dir, name.format(current_epoch))

    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict()}, out)
    return out

In [None]:
# Triplet-loss pre-training loop.
# `nr` is presumably a distributed-training rank; logging, plotting and
# checkpointing only happen when it is 0 (the only value used here).
nr = 0
current_epoch = 0
epochs = 20            # number of epochs actually trained (and reported as "/epochs")
LOG_EVERY = 50         # print loss / plot a triplet every N batches
CHECKPOINT_EVERY = 5   # save a checkpoint at the end of every N-th epoch
tr_loss = []           # mean training triplet loss per epoch
val_loss = []          # mean validation triplet loss per epoch


def _triplet_forward(x, x_p, x_n):
    """Move a triplet batch to `device`, embed all three views and compute the loss.

    Returns:
        (loss, (x, x_p, x_n) on device, (z_x, z_xp, z_xn) embeddings)
    """
    x, x_p, x_n = (t.to(device).float() for t in (x, x_p, x_n))
    z_x = model(x)
    z_xp = model(x_p)
    z_xn = model(x_n)
    return triplet_loss(z_x, z_xp, z_xn), (x, x_p, x_n), (z_x, z_xp, z_xn)


def _show_triplet(images, embeddings):
    """Plot the first anchor/positive/negative image (top row) and their embeddings (bottom row)."""
    fig, axs = plt.subplots(2, 3, figsize=(12, 8))
    titles = (("x", "z_x"), ("x_p", "z_xp"), ("x_n", "z_xn"))
    for col, ((img_title, z_title), img, z) in enumerate(zip(titles, images, embeddings)):
        # Batches are (B, 3, H, W) with three identical channels: show channel 0.
        axs[0, col].imshow(img[0, 0].detach().cpu().numpy(), cmap='gray')
        axs[0, col].set_title(img_title)
        # Embeddings are 1-D vectors; imshow needs 2-D, so draw a one-row strip.
        axs[1, col].imshow(z[0].detach().cpu().numpy().reshape(1, -1), cmap='gray', aspect='auto')
        axs[1, col].set_title(z_title)
    plt.tight_layout()
    plt.show()
    plt.close(fig)


for epoch in range(epochs):
    print(f"Epoch [{epoch}/{epochs}]\t")
    stime = time.time()

    # ---- training ----
    model.train()
    tr_loss_epoch = 0.0
    for step, (x, x_p, x_n) in enumerate(train_loader):
        optimizer.zero_grad()
        loss, images, embeddings = _triplet_forward(x, x_p, x_n)
        loss.backward()
        optimizer.step()

        if nr == 0 and step % LOG_EVERY == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {round(loss.item(), 5)}")
            _show_triplet(images, embeddings)

        tr_loss_epoch += loss.item()

    lr = optimizer.param_groups[0]["lr"]
    # Decay the learning rate once per epoch.
    scheduler.step()

    # ---- validation ----
    model.eval()
    val_loss_epoch = 0.0
    with torch.no_grad():
        for step, (x, x_p, x_n) in enumerate(test_loader):
            loss, _, _ = _triplet_forward(x, x_p, x_n)
            if nr == 0 and step % LOG_EVERY == 0:
                print(f"Step [{step}/{len(test_loader)}]\t Loss: {round(loss.item(), 5)}")
            val_loss_epoch += loss.item()

    if nr == 0:
        # Per-batch means, so train and validation losses are comparable even
        # though the loaders have very different numbers of batches.
        tr_mean = tr_loss_epoch / max(len(train_loader), 1)
        val_mean = val_loss_epoch / max(len(test_loader), 1)
        tr_loss.append(tr_mean)
        val_loss.append(val_mean)
        print(f"Epoch [{epoch}/{epochs}]\t Training Loss: {tr_mean}\t lr: {round(lr, 5)}")
        print(f"Epoch [{epoch}/{epochs}]\t Validation Loss: {val_mean}\t lr: {round(lr, 5)}")
        print(f"Epoch [{epoch}/{epochs}]\t Time: {time.time() - stime:.1f}s")
        # One checkpoint at the end of every CHECKPOINT_EVERY-th epoch (not one per batch).
        if (epoch + 1) % CHECKPOINT_EVERY == 0:
            save_model(model, optimizer, scheduler, current_epoch, "SSL_Chest_checkpoint_{}_260621.pt")
        current_epoch += 1


Epoch [0/100]	




Step [0/1654]	 Loss: 1.03532
Step [50/1654]	 Loss: 1.95905
Step [100/1654]	 Loss: 2.03513
Step [150/1654]	 Loss: 1.16873
Step [200/1654]	 Loss: 0.6047
Step [250/1654]	 Loss: 0.36417
Step [300/1654]	 Loss: 0.53508
Step [350/1654]	 Loss: 0.49154
Step [400/1654]	 Loss: 0.51618
Step [450/1654]	 Loss: 0.39792
Step [500/1654]	 Loss: 0.33009
Step [550/1654]	 Loss: 0.15012
Step [600/1654]	 Loss: 0.09468
Step [650/1654]	 Loss: 0.23119
Step [700/1654]	 Loss: 0.24123
Step [750/1654]	 Loss: 0.06305
Step [800/1654]	 Loss: 0.13805
Step [850/1654]	 Loss: 0.07746
Step [900/1654]	 Loss: 0.20462
Step [950/1654]	 Loss: 0.0
Step [1000/1654]	 Loss: 0.1018
Step [1050/1654]	 Loss: 0.09122
Step [1100/1654]	 Loss: 0.07124
Step [1150/1654]	 Loss: 0.05503
Step [1200/1654]	 Loss: 0.0
Step [1250/1654]	 Loss: 0.19951
Step [1300/1654]	 Loss: 0.17295
Step [1350/1654]	 Loss: 0.17257
Step [1400/1654]	 Loss: 0.35072
Step [1450/1654]	 Loss: 0.0
Step [1500/1654]	 Loss: 0.0
Step [1550/1654]	 Loss: 0.06161
Step [1600/1654]	



Step [0/184]	 Loss: 0.15382
Step [50/184]	 Loss: 0.00246
Step [100/184]	 Loss: 0.0
Step [150/184]	 Loss: 0.10393
Epoch [0/100]	 Training Loss: 700.5687802694738	 lr: 0.01
Epoch [0/100]	 Validation Loss: 33.07042724266648	 lr: 0.01
Epoch [1/100]	




Step [0/1654]	 Loss: 0.3435
Step [50/1654]	 Loss: 0.00177
Step [100/1654]	 Loss: 0.27973
Step [150/1654]	 Loss: 0.03939
Step [200/1654]	 Loss: 0.0
Step [250/1654]	 Loss: 0.06831
Step [300/1654]	 Loss: 0.0996
Step [350/1654]	 Loss: 0.0
Step [400/1654]	 Loss: 0.1331
Step [450/1654]	 Loss: 0.01274
Step [500/1654]	 Loss: 0.09353
Step [550/1654]	 Loss: 0.05919
Step [600/1654]	 Loss: 0.06387
Step [650/1654]	 Loss: 0.11792
Step [700/1654]	 Loss: 0.02001
Step [750/1654]	 Loss: 0.07656
Step [800/1654]	 Loss: 0.0
Step [850/1654]	 Loss: 0.06752
Step [900/1654]	 Loss: 0.15733
Step [950/1654]	 Loss: 0.13729
Step [1000/1654]	 Loss: 0.26864
Step [1050/1654]	 Loss: 0.0
Step [1100/1654]	 Loss: 0.02004
Step [1150/1654]	 Loss: 0.01127
Step [1200/1654]	 Loss: 0.06795
Step [1250/1654]	 Loss: 0.05661
Step [1300/1654]	 Loss: 0.32548
Step [1350/1654]	 Loss: 0.01167
Step [1400/1654]	 Loss: 0.0344
Step [1450/1654]	 Loss: 0.35279
Step [1500/1654]	 Loss: 0.14962
Step [1550/1654]	 Loss: 0.09445
Step [1600/1654]	 L



Step [0/184]	 Loss: 0.0
Step [50/184]	 Loss: 0.28295
Step [100/184]	 Loss: 0.20475
Step [150/184]	 Loss: 0.14781
Epoch [1/100]	 Training Loss: 196.74497278407216	 lr: 0.01
Epoch [1/100]	 Validation Loss: 29.646721355617046	 lr: 0.01
Epoch [2/100]	




Step [0/1654]	 Loss: 0.14764
Step [50/1654]	 Loss: 0.63326
Step [100/1654]	 Loss: 0.25923
Step [150/1654]	 Loss: 0.10975


Exception: ignored