In [None]:
import os

import cv2
import monai
import numpy as np
import torch
import torch.nn.functional as F
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from torch import nn
from torch.autograd import Function
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

join = os.path.join

In [None]:
class NpzDataset(Dataset):
    """
    Dataset backed by a single ``.npz`` archive.

    The archive must contain two arrays indexed along their first axis:
    ``gts`` (one 2D ground-truth mask per sample) and ``img_embeddings``
    (one precomputed image embedding per sample).

    Each item is a tuple ``(embedding, mask, bbox)``:
      * ``embedding`` - float tensor copy of ``img_embeddings[index]``
      * ``mask``      - long tensor of shape ``(1, H, W)``
      * ``bbox``      - float tensor ``[x_min, y_min, x_max, y_max]``: the tight
        box around the pixels ``> 0``, randomly enlarged and clamped to the image.

    Args:
        npz_filename: path of the ``.npz`` archive to load.
        bbox_shift: exclusive upper bound of the random enlargement (in pixels)
            applied independently to each side of the box. ``0`` disables it.
    """

    def __init__(self, npz_filename, bbox_shift=20):
        if bbox_shift < 0:
            raise ValueError(f"bbox_shift must be >= 0, got {bbox_shift}")
        self.bbox_shift = bbox_shift
        self.npz_data = np.load(npz_filename)
        self.ori_gts = self.npz_data['gts']
        self.img_embeddings = self.npz_data['img_embeddings']
        print(f"{self.img_embeddings.shape=}, {self.ori_gts.shape=}")

    def __len__(self):
        return self.ori_gts.shape[0]

    def _jitter(self):
        """Draw one random enlargement in ``[0, bbox_shift)``; 0 when disabled."""
        # np.random.randint(0, 0) raises, so the disabled case is handled here.
        if self.bbox_shift == 0:
            return 0
        return np.random.randint(0, self.bbox_shift)

    def __getitem__(self, index):
        img_embed = self.img_embeddings[index]
        gt2D = self.ori_gts[index]
        y_indices, x_indices = np.where(gt2D > 0)
        if y_indices.size == 0:
            # Without foreground there is no box to prompt with; fail with a
            # message that identifies the offending sample.
            raise ValueError(
                f"ground-truth mask at index {index} has no foreground pixels; "
                "cannot derive a bounding box"
            )
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add perturbation to bounding box coordinates (draw order: x_min,
        # x_max, y_min, y_max), keeping the box inside the image
        H, W = gt2D.shape
        x_min = max(0, x_min - self._jitter())
        x_max = min(W, x_max + self._jitter())
        y_min = max(0, y_min - self._jitter())
        y_max = min(H, y_max + self._jitter())
        bboxes = np.array([x_min, y_min, x_max, y_max])
        # convert img embedding, mask, bounding box to torch tensor
        return (
            torch.tensor(img_embed).float(),
            torch.tensor(gt2D[None, :, :]).long(),
            torch.tensor(bboxes).float(),
        )


In [None]:
# Schedule hyper-parameters consumed by GRDomainAdaptation.train_epoch, where
# p = iterations done / (n_epochs * n_batch) is the training progress:
#   learning rate:            lr    = learning_rate / (1 + ALPHA * p) ** BETA
#   gradient-reversal weight: lambd = 2 / (1 + exp(-GAMMA * p)) - 1
ALPHA = 10
BETA = 0.75
GAMMA = 10.

class GradReverse(Function):
    """
    Gradient-reversal autograd function.

    The forward pass hands back a view of the input unchanged. The backward
    pass multiplies the incoming gradient by ``-GradReverse.lambd``.

    ``lambd`` is a class attribute, so every call site shares the same
    reversal strength; it is read at backward time.
    """
    lambd = 0

    @staticmethod
    def forward(ctx, x):
        # Identity on values; view_as gives autograd a distinct output tensor.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign, then scale by the shared reversal strength.
        reversed_grad = -grad_output
        return reversed_grad * GradReverse.lambd

class DomainClassifier(nn.Module):
    """
    Chains encoder -> gradient reversal -> discriminator in one module, so a
    single backward pass trains the discriminator normally while the encoder
    receives the reversed (and ``lambd``-scaled) gradient.

    Args:
        encoder: callable mapping the input to features.
        discriminator: callable mapping those features to a domain prediction.
    """
    def __init__(self, encoder, discriminator):
        super().__init__()
        self.encoder = encoder
        self.discriminator = discriminator
        self.lambd = 0

    def update_lambd(self, lambd):
        """Store ``lambd`` on this module and on the shared ``GradReverse`` class."""
        self.lambd = GradReverse.lambd = lambd

    def forward(self, input):
        features = self.encoder(input)
        return self.discriminator(GradReverse.apply(features))

In [None]:
class Classifier(nn.Module):
    """
    Small convolutional encoder/decoder over 256-channel feature maps.

    ``encode`` reduces the input to 64 channels at a quarter of the spatial
    size (two 2x2 max-pools). ``decode`` expands that back to 256 channels and
    upsamples by 2 and then by 8. ``forward`` is ``decode(encode(input))``.
    """
    def __init__(self):
        super().__init__()
        # Encoder: 256 -> 128 -> 64 channels; the pool is applied after each conv.
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder: 64 -> 128 (x2 upsample) -> 256 -> 256 (x8 upsample).
        self.upconv1 = nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=2, stride=2)
        self.dec_conv1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.upconv2 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=8, stride=8)

    def forward(self, input):
        return self.decode(self.encode(input))

    def encode(self, input):
        """Two conv + ReLU + max-pool stages; returns 64-channel features."""
        stage1 = self.pool(F.relu(self.conv1(input)))
        stage2 = self.pool(F.relu(self.conv2(stage1)))
        return stage2

    def decode(self, input):
        """Upsample encoded features; the last layer has no activation (raw scores)."""
        hidden = F.relu(self.upconv1(input))
        hidden = F.relu(self.dec_conv1(hidden))
        return self.upconv2(hidden)

In [None]:
class Discriminator(nn.Module):
    """
    Domain discriminator for 64-channel feature maps.

    Three stride-2 convolutions (each halves an even spatial size) are followed
    by two dense layers and a sigmoid, giving one probability per sample.
    ``dense1`` takes 128 flattened features = 32 channels x 2 x 2, so the
    input feature map must be 16 x 16.
    """
    def __init__(self):
        super().__init__()
        # Convolutional stack: 64 -> 128 -> 64 -> 32 channels.
        self.conv1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1)

        # Dense head on the flattened conv output.
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(in_features=128, out_features=100)
        self.dense2 = nn.Linear(in_features=100, out_features=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        feats = input
        for conv in (self.conv1, self.conv2, self.conv3):
            feats = F.leaky_relu(conv(feats))

        hidden = self.relu(self.dense1(self.flatten(feats)))
        return self.sigmoid(self.dense2(hidden))


In [None]:
class GRDomainAdaptation:
    """
    Trainer that fine-tunes SAM's mask decoder on the source domain while
    training a gradient-reversal domain classifier on source + target data.

    Both datasets are ``NpzDataset`` archives of precomputed image embeddings.
    Per iteration:
      1. the SAM mask decoder is updated with a Dice + CE loss on the source batch;
      2. ``Classifier.encode`` -> gradient reversal -> ``Discriminator`` is
         trained to tell source embeddings (label 0) from target ones (label 1).
         The discriminator descends that loss; the encoder gets the reversed gradient.

    Args:
        source_npz: path of the source-domain ``.npz`` archive.
        target_npz: path of the target-domain ``.npz`` archive.
        sam_checkpoint: path of the SAM checkpoint to load.
        model_type: key into ``sam_model_registry``.
        output_dir: directory where ``sam_model_best.pth`` is written.
    """

    def __init__(self,
                 source_npz='data/demo2D_vit_b/source_dataset.npz',
                 target_npz='data/demo2D_vit_b/target_dataset.npz',
                 sam_checkpoint='work_dir/SAM/sam_vit_b_01ec64.pth',
                 model_type='vit_b',
                 output_dir='work_dir/demo2D/'):
        ###########################
        # Initialize Info Holders #
        ###########################
        self.n_epochs = 50
        self.batch_size_train = 64
        self.batch_size_test = 1000
        self.learning_rate = 0.01
        self.momentum = 0.9
        self.log_interval = 10

        self.random_seed = 1
        # Use the GPU when there is one instead of assuming it exists.
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.cuda else 'cpu')
        self.check_pth = '/home/ubuntu/nadav/GradientReversal/weights/mnist2mnist_m'
        self.output_dir = output_dir

        # Apply the declared seed so weight init, shuffling and the random
        # bounding-box jitter are reproducible.
        torch.manual_seed(self.random_seed)
        np.random.seed(self.random_seed)

        self.source_best_pred = 0.0
        self.target_best_pred = 0.0
        self.best_source_net_state = None
        self.best_target_net_state = None
        self.source_test_losses = []
        self.target_test_losses = []
        self.source_test_acc = []
        self.target_test_acc = []
        self.iters = 0

        #######################################
        # Initialize Source and target labels #
        #######################################
        # Label buffers (source = 0, target = 1); train_epoch slices them to
        # the actual size of each batch.
        self.source_disc_labels = torch.zeros(size=(self.batch_size_train, 1), device=self.device)
        self.target_disc_labels = torch.ones(size=(self.batch_size_train, 1), device=self.device)

        #########################################################
        # Load the source and target datasets into data loaders #
        #########################################################
        self.source_dataset = NpzDataset(source_npz)
        self.target_dataset = NpzDataset(target_npz)

        self.source_train_loader = DataLoader(self.source_dataset, batch_size=8, shuffle=True)
        self.target_train_loader = DataLoader(self.target_dataset, batch_size=8, shuffle=True)

        # zip() over the two loaders stops at the shorter one.
        self.n_batch = min(len(self.target_train_loader), len(self.source_train_loader))

        ###################
        # Define networks #
        ###################
        # Feature network whose encoder feeds the domain discriminator. It has
        # its own attribute so it is not lost when the SAM model is stored in
        # self.net, and so its parameters can be handed to encoder_optimizer.
        self.feature_net = Classifier().to(self.device)
        self.encoder = self.feature_net.encode

        # SAM model; only its mask decoder is optimised (see net_optimizer).
        self.net = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(self.device)

        # Built once here rather than on every training iteration.
        self.sam_trans = ResizeLongestSide(self.net.image_encoder.img_size)
        self.seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

        ###################################################
        # Set Domain Classifier (Encoder + Discriminator) #
        ###################################################
        self.discriminator = Discriminator()
        self.domain_classifier = DomainClassifier(self.encoder, self.discriminator)
        if self.cuda:
            self.domain_classifier = torch.nn.DataParallel(self.domain_classifier, device_ids=[0])
        self.domain_classifier = self.domain_classifier.to(self.device)

        #####################
        # Define Optimizers #
        #####################
        self.net_optimizer = torch.optim.Adam(self.net.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
        # The encoder optimizer must own the feature network's parameters: they
        # are the only non-discriminator weights that receive the domain-loss
        # gradient. SGD skips parameters that have no gradient.
        self.encoder_optimizer = torch.optim.SGD(self.feature_net.parameters(), self.learning_rate, momentum=self.momentum)
        self.discriminator_optimizer = torch.optim.SGD(self.discriminator.parameters(), lr=self.learning_rate, momentum=self.momentum)

    @staticmethod
    def _scale_learning_rate(optimizer, decay):
        """Set each param group's lr to its initial lr times ``decay``.

        The initial rate is remembered in the group under ``'initial_lr'`` on
        first use, so repeated calls never compound. Assigning ``optimizer.lr``
        would have no effect: optimizers read the rate from ``param_groups``.
        """
        for group in optimizer.param_groups:
            base_lr = group.setdefault('initial_lr', group['lr'])
            group['lr'] = base_lr * decay

    def train_epoch(self):
        """Run one pass over the zipped source/target loaders.

        Returns:
            Tuple ``(mean segmentation loss, mean discriminator loss, mean
            total loss)`` as floats; all ``inf`` if no batch was processed.
        """
        self.net.train()
        net_loss = 0.0
        disc_loss = 0.0
        total_loss = 0.0
        n_seen = 0

        # Iterate the progress bar itself so that it actually advances.
        tbar = tqdm(zip(self.source_train_loader, self.target_train_loader), total=self.n_batch)
        for i, ((image_embedding, gt2D, boxes), (target_img, _, _)) in enumerate(tbar):
            ##############################
            # update learning parameters #
            ##############################
            self.iters += 1
            p = self.iters / (self.n_epochs * self.n_batch)

            # Reversal strength ramps from 0 towards 1 as training progresses.
            lambd = float((2. / (1. + np.exp(-GAMMA * p))) - 1)
            if self.cuda:
                self.domain_classifier.module.update_lambd(lambd)
            else:
                self.domain_classifier.update_lambd(lambd)

            # Decay every optimizer relative to its own initial rate, so the
            # Adam mask-decoder rate is not replaced by the SGD base rate.
            decay = 1. / (1. + ALPHA * p) ** BETA
            for optimizer in (self.discriminator_optimizer, self.net_optimizer, self.encoder_optimizer):
                self._scale_learning_rate(optimizer, decay)

            image_embedding = image_embedding.to(self.device)
            gt2D = gt2D.to(self.device)
            target_img = target_img.to(self.device)

            #########################################################################
            # domain labels sized per batch: source and target batches may differ  #
            #########################################################################
            source_disc_labels = self.source_disc_labels[:image_embedding.shape[0]]
            target_disc_labels = self.target_disc_labels[:target_img.shape[0]]

            ############################################
            # Train SAM mask decoder on Source Data    #
            ############################################
            with torch.no_grad():
                # rescale the boxes to SAM's input resolution
                box = self.sam_trans.apply_boxes(boxes.numpy(), (gt2D.shape[-2], gt2D.shape[-1]))
                box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
                if len(box_torch.shape) == 2:
                    box_torch = box_torch[:, None, :]  # (B, 1, 4)
                # get prompt embeddings
                sparse_embeddings, dense_embeddings = self.net.prompt_encoder(
                    points=None,
                    boxes=box_torch,
                    masks=None,
                )

            net_output, _ = self.net.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=self.net.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            class_net_loss = self.seg_loss(net_output, gt2D)
            self.net_optimizer.zero_grad()
            class_net_loss.backward()
            self.net_optimizer.step()

            #########################################
            # Train encoder on Source + Target data #
            #########################################
            self.encoder_optimizer.zero_grad()
            self.discriminator_optimizer.zero_grad()
            dom_input = torch.cat([image_embedding, target_img], dim=0)
            dom_labels = torch.cat([source_disc_labels, target_disc_labels], dim=0)
            dom_output = self.domain_classifier(dom_input)
            dom_loss = F.binary_cross_entropy(dom_output, dom_labels)

            dom_loss.backward()
            self.discriminator_optimizer.step()
            self.encoder_optimizer.step()

            # Accumulate plain floats: summing loss tensors would keep their
            # autograd history alive across iterations.
            seg_value = class_net_loss.item()
            dom_value = dom_loss.item()
            net_loss += seg_value
            disc_loss += dom_value
            total_loss += seg_value - lambd * dom_value
            n_seen = i + 1

            tbar.set_description('Net loss: {0:.6f}; Discriminator loss: {1:.6f}; Total Loss: {2:.6f}; {3:.2f}%;'.format(
                net_loss / n_seen,
                disc_loss / n_seen,
                total_loss / n_seen,
                n_seen / self.n_batch * 100))

        if n_seen == 0:
            return float('inf'), float('inf'), float('inf')
        return net_loss / n_seen, disc_loss / n_seen, total_loss / n_seen

    def train(self):
        """Train for ``n_epochs`` and keep the SAM weights of the best epoch.

        "Best" means the lowest mean source segmentation loss over an epoch
        (no validation split is used here). Whenever an epoch improves on it,
        a CPU copy of ``self.net.state_dict()`` is stored in
        ``self.best_target_net_state`` and written to
        ``<output_dir>/sam_model_best.pth``.
        """
        os.makedirs(self.output_dir, exist_ok=True)
        best_path = os.path.join(self.output_dir, 'sam_model_best.pth')
        best_loss = float('inf')
        for epoch in range(self.n_epochs):
            print('Epoch: {}; Source Best: {}; Target Best: {}'.format(epoch, self.source_best_pred, self.target_best_pred))
            epoch_net_loss, _, _ = self.train_epoch()
            if epoch_net_loss < best_loss:
                best_loss = epoch_net_loss
                # Clone to CPU so later optimizer steps cannot alter the snapshot.
                self.best_target_net_state = {
                    key: value.detach().cpu().clone()
                    for key, value in self.net.state_dict().items()
                }
                torch.save(self.best_target_net_state, best_path)

# Build the trainer (constructs the datasets, the SAM model and the optimizers)
# and run the full training loop; `trainer` stays in the notebook namespace.
trainer = GRDomainAdaptation()
trainer.train()