In [0]:
# Make Google Drive visible under /content/drive; the dataset archive is
# copied from there in the next cell.
import google.colab.drive as drive

drive.mount('/content/drive')

In [0]:
#importing dataset
!cp '/content/drive/My Drive/check.zip' '/content/'
!unzip -q "/content/check.zip" 

In [0]:
!pip install pytorch-lightning

In [0]:
from dependencies import *
from models import *
from IPython.display import display # to display images
import torchvision.transforms.functional as TF
import random
from tqdm.auto import tqdm
import torch.nn.functional as F

In [0]:
class MyRotationTransform:
    """Rotate an image and its mask together by one randomly chosen angle.

    Args:
        angles: sequence of candidate angles; one is drawn uniformly with
            ``random.choice`` on every call and applied to both inputs.
    """

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x, mask):
        """Return ``(rotated_x, rotated_mask)`` using the same sampled angle.

        The image is rotated with ``fill=(0,)``; the mask uses ``TF.rotate``'s
        default fill.
        """
        chosen_angle = random.choice(self.angles)
        rotated_mask = TF.rotate(mask, chosen_angle)
        rotated_image = TF.rotate(x, chosen_angle, fill=(0,))
        return rotated_image, rotated_mask

class brightTransform:
    """Adjust image brightness by a factor drawn at random from a fixed set.

    Args:
        bightnesses: sequence of brightness factors handed to
            ``TF.adjust_brightness``; one is drawn uniformly per call.
            (The misspelled name is kept because callers pass it by keyword.)
    """

    def __init__(self, bightnesses):
        self.bightnesses = bightnesses

    def __call__(self, x):
        """Return ``x`` with its brightness adjusted by a randomly chosen factor."""
        factor = random.choice(self.bightnesses)
        adjusted = TF.adjust_brightness(x, factor)
        return adjusted

In [0]:
#dataset
class OCT_dataset(Dataset):
    """Dataset pairing each 8-bit image with its boundary-annotation mask.

    ``root_dir`` must contain two sub-directories, ``8bit`` (inputs) and
    ``black`` (masks). Files are paired by position in the sorted directory
    listings, so both directories must hold the same number of files in the
    same order.

    Each item is ``(image_tensor, target, image_filename)``; see
    ``__getitem__`` for the target layout.

    Args:
        root_dir: dataset root containing ``8bit`` and ``black``.
        train: stored as ``self.train``; only referenced by the currently
            commented-out augmentation code.

    Raises:
        ValueError: if ``8bit`` and ``black`` contain different file counts.
    """

    # (height, width) every image/mask pair is resized to. The width also sets
    # the per-channel stride of the flattened target vector.
    IMAGE_SIZE = (496, 523)

    def __init__(self, root_dir, train=False):
        self.root_dir = root_dir
        self.train = train
        self.imgs8bit_list = sorted(os.listdir(os.path.join(root_dir, "8bit")))
        self.black_targets_list = sorted(os.listdir(os.path.join(root_dir, "black")))
        # Images and masks are paired purely by sorted position; differing
        # counts would silently mis-pair files instead of failing.
        if len(self.imgs8bit_list) != len(self.black_targets_list):
            raise ValueError(
                f"{root_dir}: '8bit' has {len(self.imgs8bit_list)} files but "
                f"'black' has {len(self.black_targets_list)}; they are paired by sorted order"
            )
        self.rotation_transform = MyRotationTransform(angles=[-15, -7, 0, 7, 15])
        self.bright_transform = brightTransform(bightnesses=[0.8, 0.87, 1, 1.12, 1.2])

    def transform_dir(self, image, mask):
        """Resize an (image, mask) pair of PIL images to ``IMAGE_SIZE``.

        The mask is resized with nearest-neighbour interpolation. The default
        ``T.Resize`` interpolation is smooth (bilinear, per torchvision). It
        blends the label colour into neighbouring pixels. ``__getitem__`` picks
        the second-smallest distinct mask value as the label, so it would then
        match a faint interpolated edge instead of the annotated line.
        """
        image = T.Resize(size=self.IMAGE_SIZE)(image)
        mask = T.Resize(size=self.IMAGE_SIZE, interpolation=Image.NEAREST)(mask)

        # Augmentation is currently disabled; kept for reference.
        # NOTE(review): `random.random() > 1.0` is never true, so the rotation
        # would stay off even if this block were re-enabled.
        #if self.train:
        #    # Random horizontal flipping
        #    if random.random() > 0.75:
        #        image = TF.hflip(image)
        #        mask = TF.hflip(mask)
        #    if random.random() > 1.0:
        #        image, mask = self.rotation_transform(image, mask)

        return image, mask

    def transform_image(self, image):
        """Convert an image array to a tensor normalised with mean 0.5, std 0.5.

        Goes through PIL, presumably so the (disabled) brightness augmentation
        can operate on a PIL image.
        """
        image = T.ToPILImage()(image)

        #if self.train:
        #    if random.random() > 0.6:
        #        image = self.bright_transform(image)

        image = TF.to_tensor(image)
        return T.Normalize((0.5,), (0.5,))(image)

    def __len__(self):
        return len(self.imgs8bit_list)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target, image_filename)`` for sample ``idx``.

        ``target`` is a float64 tensor of length ``3 * width``. For mask
        channel ``i`` and column ``c``, entry ``i * width + c`` holds the row
        of a pixel equal to the label value in that column. It is 0 if the
        column has no such pixel. If several rows match, which one is kept is
        unspecified.

        Raises:
            ValueError: if the mask has a single distinct value (no label).
        """
        img_name = self.imgs8bit_list[idx]
        img8bit_instance_path = os.path.join(self.root_dir, "8bit", img_name)
        black_target_instance_path = os.path.join(self.root_dir, "black", self.black_targets_list[idx])

        img8bit = Image.open(img8bit_instance_path)
        black_target = Image.open(black_target_instance_path)
        img8bit, black_target = self.transform_dir(img8bit, black_target)

        img8bit = np.array(img8bit)
        black_target = np.array(black_target)  # indexed below as (H, W, >=3 channels)

        unique_labels = np.unique(black_target)
        if unique_labels.size < 2:
            raise ValueError(f"mask {black_target_instance_path} has no labelled pixels")
        # Second-smallest distinct value; presumably 255 on a 0 background.
        unique_label = unique_labels[1]

        width = black_target.shape[1]
        target_result = np.zeros(3 * width)
        for channel in range(3):
            rows, cols = np.where(black_target[:, :, channel] == unique_label)
            target_result[cols + channel * width] = rows

        img8bit = self.transform_image(img8bit)
        target_result = torch.from_numpy(target_result)
        return img8bit, target_result, img_name


In [0]:
class Swish(nn.Module):
    """Element-wise Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class ConvBlock(nn.Module):
    """Depthwise conv block returning its features and a 1x1 projection of them.

    Layer stack (``CNNnetwork``), in order:
      1x1 conv (input -> output channels, no bias) -> BatchNorm ->
      depthwise ``kernel`` x ``kernel`` conv with ``stride`` and no padding ->
      BatchNorm -> 1x1 conv halving channels -> 1x1 conv restoring them.
    Swish is applied once, after the whole stack.

    Loosely modelled on a mobile inverted-bottleneck block ("not exactly the
    same"). NOTE(review): the reduce/expand pair looks like a squeeze-and-
    excitation gate, but it is applied as plain sequential convs with no
    activation in between. Confirm this is intended before changing it,
    since any change alters the model.

    Args:
        input_channels: channels of the incoming tensor.
        output_channels: channels of the block's feature output.
        kernel: depthwise kernel size.
        stride: depthwise stride.
        projectile_dim: channels of the 1x1 projection returned second.
    """

    def __init__(self, input_channels, output_channels, kernel, stride, projectile_dim):
        super().__init__()
        squeezed = int(output_channels / 2)

        # Attribute names and construction order are preserved: they define
        # state_dict keys and the order in which weights are initialised.
        self.expand_cnn = nn.Conv2d(input_channels, output_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(output_channels)

        self.depthwise_conv = nn.Conv2d(output_channels, output_channels, kernel_size=kernel,
                                        stride=stride, groups=output_channels, bias=False)
        self.bn2 = nn.BatchNorm2d(output_channels)

        self.add_reduce = nn.Conv2d(output_channels, squeezed, kernel_size=1)
        self.add_expand = nn.Conv2d(squeezed, output_channels, kernel_size=1)

        self.CNNnetwork = nn.Sequential(
            self.expand_cnn,
            self.bn1,
            self.depthwise_conv,
            self.bn2,
            self.add_reduce,
            self.add_expand,
        )

        self.swish = Swish()

        self.projectile = nn.Conv2d(output_channels, projectile_dim, kernel_size=1, bias=False)

    def forward(self, data):
        """Return ``(features, projection)``; projection is a 1x1 conv of features."""
        features = self.swish(self.CNNnetwork(data))
        return features, self.projectile(features)

In [0]:
class CNNcheck(pl.LightningModule):
    """LightningModule regressing ``523 * 3`` target values from an OCT image.

    Backbones, selected by ``types``:
      * ``'cnn'``: six ``CNNone`` cells (from ``models``). Their output is
        flattened to ``32*8*7`` features, then passed through ``fc1``
        (2000 units), dropout(0.5) and ``fc2``.
      * ``'att'``: six ``ConvBlock`` cells. The 1x1 projections of cells 3-6
        are flattened and concatenated; the total is hard-coded as 55520
        features. They then go through a single-token attention step and
        ``fc``.

    Data: ``OCT_dataset`` rooted at ``data_dir``, or at the notebook-global
    ``path`` when ``data_dir`` is None. 10% of the samples are held out as
    test and the rest are split 80/20 into train/val.

    Args:
        types: ``'cnn'`` or ``'att'``.
        data_dir: dataset root; ``None`` keeps the old behaviour of reading
            the global ``path``.

    Raises:
        ValueError: if ``types`` is not a supported backbone.
    """

    def __init__(self, types = 'cnn', data_dir=None):
        super().__init__()
        if types not in ('cnn', 'att'):
            # Previously an unknown value built no layers and forward() returned
            # its input unchanged, failing much later inside the loss.
            raise ValueError(f"types must be 'cnn' or 'att', got {types!r}")
        self.types = types
        self.data_dir = data_dir
        if types == 'cnn':
            self.CNNcell1 = CNNone(input_channels = 1, output_channels = 16)
            self.CNNcell2 = CNNone(input_channels = 16, output_channels = 32)
            self.CNNcell3 = CNNone(input_channels = 32, output_channels = 64)
            self.CNNcell4 = CNNone(input_channels = 64, output_channels = 64)
            self.CNNcell5 = CNNone(input_channels = 64, output_channels = 32)
            self.CNNcell6 = CNNone(input_channels = 32, output_channels = 32)
            #self.CNNcell7 = CNNone(input_channels = 32, output_channels = 16)
            self.CNNnetwork = nn.Sequential(self.CNNcell1, self.CNNcell2,
                                            self.CNNcell3, self.CNNcell4, self.CNNcell5, self.CNNcell6)
            self.fc1 = nn.Linear(in_features = 32*8*7, out_features = 2000)
            self.dropout1 = nn.Dropout(0.5)
            self.fc2 = nn.Linear(in_features = 2000, out_features = 523*3)
        else:  # 'att'
            self.CNNcell1 = ConvBlock(input_channels = 1, output_channels = 2, kernel = 3, stride = 2, projectile_dim = 2)
            self.CNNcell2 = ConvBlock(input_channels = 2, output_channels = 4, kernel = 3, stride = 2, projectile_dim = 4)
            self.CNNcell3 = ConvBlock(input_channels = 4, output_channels = 8, kernel = 3, stride = 2, projectile_dim = 8)
            self.CNNcell4 = ConvBlock(input_channels = 8, output_channels = 16, kernel = 3, stride = 2, projectile_dim = 16)
            self.CNNcell5 = ConvBlock(input_channels = 16, output_channels = 32, kernel = 3, stride = 2, projectile_dim = 32)
            self.CNNcell6 = ConvBlock(input_channels = 32, output_channels = 64, kernel = 3, stride = 2, projectile_dim = 64)

            # NOTE(review): these projections are registered (so existing
            # checkpoints still load) but forward() does not use them; its
            # outputs were previously computed and immediately discarded.
            self.fc_attq = nn.Linear(in_features = 55520, out_features = 1000)
            self.fc_attk = nn.Linear(in_features = 55520, out_features = 1000)
            self.fc_attv = nn.Linear(in_features = 55520, out_features = 1000)

            self.fc = nn.Linear(in_features = 55520, out_features = 523*3)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 523*3)`` prediction."""
        if self.types == "cnn":
            #x = x[:, None]
            x = self.CNNnetwork(x)
            x = x.view(-1, 32*8*7)
            x = self.dropout1(self.fc1(x))
            x = self.fc2(x)

        elif self.types == "att":
            batch_size = x.size(0)
            x, _ = self.CNNcell1(x)
            x, _ = self.CNNcell2(x)
            x, y1 = self.CNNcell3(x)
            x, y2 = self.CNNcell4(x)
            x, y3 = self.CNNcell5(x)
            x, y4 = self.CNNcell6(x)
            out = torch.cat([y.view(batch_size, -1) for y in (y1, y2, y3, y4)], dim=1)

            # NOTE(review): q, k and v are all the raw features as a single
            # token per sample. The score matrix is therefore (B, 1, 1), and a
            # softmax over a size-1 axis is 1 for any finite score, so this step
            # returns `out` unchanged. A real attention layer needs several
            # tokens and would use the fc_att* projections.
            q = k = v = out[:, None]
            attn_output_weights = F.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)
            x = torch.bmm(attn_output_weights, v).view(batch_size, -1)
            x = self.fc(x)

        return x

    def prepare_data(self):
        """Build the train/val/test subsets.

        The test subset is a random 10% of the samples, drawn from a
        ``train=False`` dataset instance. The remaining samples are split
        80/20 into train/val.
        """
        root = self.data_dir if self.data_dir is not None else path
        dataset = OCT_dataset(root, train = True)
        dataset_test = OCT_dataset(root, train = False)

        indices = torch.randperm(len(dataset)).tolist()
        size_of_test = int(len(dataset) * 0.1)
        size_of_main = len(dataset) - size_of_test
        size_of_train = int(size_of_main * 0.8)

        # Explicit bounds: the old `indices[:-size_of_test]` becomes
        # `indices[:0]` (empty) when size_of_test == 0, i.e. < 10 samples.
        dataset = torch.utils.data.Subset(dataset, indices[:size_of_main])
        self.dataset_test = torch.utils.data.Subset(dataset_test, indices[size_of_main:])
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(
            dataset, [size_of_train, size_of_main - size_of_train])

    def train_dataloader(self):
        # shuffle=True so mini-batch composition changes between epochs;
        # random_split only permutes the data once.
        return DataLoader(self.dataset_train, batch_size=40, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.dataset_val, batch_size=40, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.dataset_test, batch_size=40, num_workers=4)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr = 1e-3)

    def loss_funtion(self, input, target):
        """Sum (not mean) of squared errors over the whole batch and all outputs."""
        return torch.sum((input - target) ** 2)

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        loss = self.loss_funtion(self.forward(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        return {'val_loss': self.loss_funtion(self.forward(x), y)}

    def validation_epoch_end(self, outputs):
        # Mean of per-batch *summed* losses; a smaller final batch makes this
        # differ slightly from a true per-sample average.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        x, y, _ = batch
        return {'test_loss': self.loss_funtion(self.forward(x), y)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': logs, 'progress_bar': logs}


In [0]:
# Seed every RNG this run touches: the model's weight initialisation below,
# and the torch.randperm / random_split calls in CNNcheck.prepare_data. This
# keeps the train/val/test split identical across kernel restarts.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

path = "/content"  # dataset root read by CNNcheck.prepare_data (holds 8bit/ and black/)
cnn_oct = CNNcheck(types = 'cnn')
# most basic trainer, uses good defaults (1 gpu)
trainer = pl.Trainer(gpus=1, profiler=True,
                     #auto_lr_find=True, #set hparams
                     gradient_clip_val=0.5,
                     check_val_every_n_epoch=5,
                     #early_stop_callback=True,
                     max_epochs = 600,
                     #min_epochs=400,
                     progress_bar_refresh_rate = 12)
trainer.fit(cnn_oct)

In [0]:
# Evaluate on the held-out test subset built in CNNcheck.prepare_data; the
# reported test_loss is the mean over batches of the summed squared error.
# NOTE(review): without a model argument, some Lightning versions test the
# best checkpoint rather than the final weights — confirm which is intended.
trainer.test()

In [0]:
# Inline TensorBoard viewer for the train_loss / val_loss / test_loss values
# logged by CNNcheck (presumably written by Lightning's default logger under
# lightning_logs/).
%load_ext tensorboard
%tensorboard --logdir lightning_logs/