In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from Contrast_loss_mem import ContrastLoss
from model import Model
# Restrict the process to GPU 0. This must run before CUDA is first initialised,
# which is why it precedes the torch.cuda.is_available() check below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

# Seed every RNG the notebook relies on (DataLoader shuffling, randperm, model init).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Display the installed PyTorch version as the cell output, to record the environment.
torch.__version__

'1.9.0+cu102'

In [None]:
# Instantiate the segmentation model defined in the local `model` module.
# Its forward pass returns a dict with 'seg' and 'embed' entries (as used in training).
model = Model()

In [None]:
def AC_loss(y_true, y_pred):
    """Active-contour style loss along axis 2: weighted contour length + region terms.

    Args:
        y_true: target mask; axis 2 is treated as the sequence axis
            (presumably (batch, 1, length), as built by the data cell — TODO confirm).
        y_pred: predicted probabilities with the same shape as ``y_true``.

    Returns:
        Tensor reduced over axis 2 (e.g. (batch, 1) for 3-D inputs); callers
        average it with ``torch.mean``.
    """
    # Forward differences of the prediction along axis 2.
    diffs = y_pred[:, :, 1:] - y_pred[:, :, :-1]
    # NOTE(review): the last two differences are dropped here; this mirrors the
    # alignment slice of the 2-D active-contour formulation and may be
    # unintended for 1-D signals.
    squared = diffs[:, :, :-2].pow(2)
    contour_length = (squared + 1e-6).sqrt().mean(dim=2)

    # Region terms with fixed inside (1) / outside (0) reference values.
    inside = (y_pred * (y_true - torch.ones_like(y_true)).pow(2)).mean(dim=2).abs()
    outside = ((1 - y_pred) * (y_true - torch.zeros_like(y_true)).pow(2)).mean(dim=2).abs()

    # The length term is weighted by a fixed factor of 6.
    return 6 * contour_length + (inside + outside)


def dice_metric(y_true, y_pred):
    """Soft Dice coefficient computed over every element of the inputs.

    Args:
        y_true: target tensor.
        y_pred: prediction tensor, expected to have the same shape as ``y_true``.

    Returns:
        0-dim tensor ``(2*sum(y_true*y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)``
        with ``eps = 1e-7``; it evaluates to 1 when both inputs are all zeros.
    """
    eps = 0.0000001
    overlap = (y_pred * y_true).sum()
    denominator = y_true.sum() + y_pred.sum() + eps
    return (2. * overlap + eps) / denominator

# Binary cross-entropy criterion (kept as an alternative segmentation loss).
CE_loss = nn.BCELoss()
# Project-local contrastive loss from Contrast_loss_mem; the training loop calls it
# as contrast_criterion(proj_feature, labels, preds, pixel_queue).
contrast_criterion = ContrastLoss()


In [None]:
def _with_channel_axis(array):
    """Insert a singleton channel axis: (N, L) -> (N, 1, L); presumably 2-D input."""
    n_rows, n_cols = array.shape[0], array.shape[1]
    return array.reshape((n_rows, 1, n_cols))


# Load each split from disk and give it a channel axis at position 1.
X_train = _with_channel_axis(np.load('data/X_train.npy'))
X_val = _with_channel_axis(np.load('data/X_val.npy'))
y_seg_train = _with_channel_axis(np.load('data/y_seg_train.npy'))
y_seg_val = _with_channel_axis(np.load('data/y_seg_val.npy'))



In [None]:
# Convert every split to a float32 CPU tensor (torch.FloatTensor).
X_train, X_val, y_seg_train, y_seg_val = (
    torch.FloatTensor(split)
    for split in (X_train, X_val, y_seg_train, y_seg_val)
)

In [None]:
def _dequeue_and_enqueue(keys, labels,
                             pixel_queue, pixel_queue_ptr):
        batch_size = keys.shape[0]
        feat_dim = keys.shape[1]
        labels = labels + 1
        keys = keys + 1


        for bs in range(batch_size):
            this_feat = keys[bs].contiguous().view(feat_dim, -1)
            this_label = labels[bs].contiguous().view(-1)
            this_label_ids = torch.unique(this_label)
            this_label_ids = [x for x in this_label_ids if x > 0]

            for lb in this_label_ids:
                idxs = (this_label == lb).nonzero()


                num_pixel = idxs.shape[0]
                perm = torch.randperm(num_pixel)
                K = min(num_pixel, self.pixel_update_freq)
                feat = this_feat[:, perm[:K]]
                feat = torch.transpose(feat, 0, 1)
                ptr = int(pixel_queue_ptr[lb])

                if ptr + K >= 1000:
                    pixel_queue[lb, -K:, :] = nn.functional.normalize(feat, p=2, dim=1)
                    pixel_queue_ptr[lb] = 0
                else:
                    pixel_queue[lb, ptr:ptr + K, :] = nn.functional.normalize(feat, p=2, dim=1)
                    pixel_queue_ptr[lb] = (pixel_queue_ptr[lb] + 1) % 1000
        return pixel_queue,pixel_queue_ptr

# Per-class memory bank of feature vectors plus one write pointer per class.
QUEUE_NUM_CLASSES, QUEUE_SIZE, QUEUE_FEAT_DIM = 2, 5000, 128
pixel_queue = torch.zeros([QUEUE_NUM_CLASSES, QUEUE_SIZE, QUEUE_FEAT_DIM])
pixel_queue_ptr = torch.zeros([QUEUE_NUM_CLASSES])

In [7]:
import torch.utils.data as Data

BATCH_SIZE = 64

# Training batches are reshuffled every epoch.
train_data = Data.TensorDataset(X_train, y_seg_train)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Validation order cannot change an aggregate metric, so keep it deterministic.
val_data = Data.TensorDataset(X_val, y_seg_val)
val_loader = Data.DataLoader(
    dataset=val_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [8]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR  # NOTE(review): imported but not used below.

LEARNING_RATE = 0.0005
# Adam over all parameters returned by model.parameters().
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


In [9]:
import math

# Optimisation steps per epoch, used to average the running loss.
# (The name is kept as-is in case later cells refer to it.)
bacth_len = math.ceil(X_train.shape[0] / train_loader.batch_size)
NUM_EPOCHS = 5000
CONTRAST_WEIGHT = 0.1
SAVE_DIR = 'Save_Model'

best_valid_dice = 0
model = model.to(device)
# Keep the memory bank on the same device as the embeddings it is compared with.
pixel_queue = pixel_queue.to(device)
os.makedirs(SAVE_DIR, exist_ok=True)


for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    model.train()
    for batch_idx, (x, labels) in enumerate(train_loader):
        x = x.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        # A single forward pass feeds both heads. Calling model(x) twice doubled
        # the cost, and in train mode the two heads could see different
        # dropout/BN state.
        model_out = model(x)
        outputs, proj_feature = model_out['seg'], model_out['embed']

        # Hard 0/1 predictions for the contrastive loss (no gradient needed).
        preds = (outputs.detach() > 0.5).float()

        # Enqueue the embeddings (not the 1-channel segmentation probabilities)
        # into the 128-d memory bank, detached from the graph.
        # NOTE(review): assumes proj_feature is (B, 128, L) with L equal to the
        # label length — confirm against Model's 'embed' head.
        pixel_queue, pixel_queue_ptr = _dequeue_and_enqueue(
            proj_feature.detach(), labels, pixel_queue, pixel_queue_ptr)

        loss = (torch.mean(AC_loss(labels, outputs))
                + CONTRAST_WEIGHT * contrast_criterion(proj_feature, labels, preds, pixel_queue))
        loss.backward()
        optimizer.step()
        running_loss += torch.mean(loss).item()

    # Validation Dice over the whole validation set, accumulated batch by batch.
    model.eval()
    intersection = 0.0
    mask_total = 0.0
    with torch.no_grad():
        for val_x, val_y in val_loader:
            val_x = val_x.to(device)
            val_y = val_y.to(device)
            # Threshold BEFORE any integer cast: casting probabilities to int
            # first truncated every value in [0, 1) to 0.
            val_pred = (model(val_x)['seg'] > 0.5).float()
            intersection += torch.sum(val_pred * val_y).item()
            mask_total += (torch.sum(val_y) + torch.sum(val_pred)).item()

    smooth = 0.00001
    dice = (2. * intersection + smooth) / (mask_total + smooth)
    print('val_dice:', dice)

    if dice > best_valid_dice:
        best_valid_dice = dice
        print("save model")
        torch.save(model, os.path.join(SAVE_DIR, 'model-2023-5-31-1'))
        torch.save(model.state_dict(), os.path.join(SAVE_DIR, 'model_parameter-2023-5-31-1.pkl'))

    print('epoch:{} batch_idx:{} loss:{}'
          .format(epoch + 1, batch_idx + 1, running_loss / bacth_len))