In [1]:
import random
import os
import sys
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np

import torch.nn as nn
from torch.autograd import Function
from torchvision import datasets, models
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#constants and hyperparameters
# device = 'cpu'
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
lr = 1e-3
batch_size = 16
image_size = 28
n_epoch = 100
manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)


<torch._C.Generator at 0x7ff8cc6e0110>

## Dataset loading

In [3]:
cataracts_data_dir = "../../data/final_cataracts"
cataracts_train_dir = os.path.join(cataracts_data_dir, 'train')
cataracts_val_dir = os.path.join(cataracts_data_dir, 'val')
test_dir = os.path.join(cataracts_data_dir, 'test')

d99_balanced_datadir = "../../data/final_d99/"
d99_balanced_train_dir = os.path.join(d99_balanced_datadir, 'train')
d99_balanced_val_dir = os.path.join(d99_balanced_datadir, 'val')
d99_balanced_test_dir = os.path.join(d99_balanced_datadir, 'test')

In [4]:
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224,224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5726915, 0.35134485, 0.20473212], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5726915, 0.35134485, 0.20473212], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5726915, 0.35134485, 0.20473212], [0.229, 0.224, 0.225])
    ]),
}

image_datasets = {x: datasets.ImageFolder(os.path.join(cataracts_data_dir,x),data_transforms[x]) for x in ['train','val','test']}

cataracts_dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
                for x in ['train','val','test']
}

cataracts_dataset_sizes = {x: len(image_datasets[x]) for x in ['train','val','test']}
cataracts_class_names = image_datasets['train'].classes


In [5]:

data_transforms99 = {
    'train': transforms.Compose([
        transforms.Resize((224,224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.46083382, 0.34022495, 0.3280154], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46083382, 0.34022495, 0.3280154], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46083382, 0.34022495, 0.3280154], [0.229, 0.224, 0.225])
    ]),
}

image_datasets99_balanced = {x: datasets.ImageFolder(os.path.join(d99_balanced_datadir,x),data_transforms99[x]) for x in ['train','val','test']}

d99_balanced_dataloaders= {x: torch.utils.data.DataLoader(image_datasets99_balanced[x], batch_size=batch_size, shuffle=True, num_workers=4)
                for x in ['train','val','test']
}

d99_balanced_dataset_sizes = {x: len(image_datasets99_balanced[x]) for x in ['train','val','test']}
d99_balanced_class_names = image_datasets99_balanced['train'].classes


## model definition


In [6]:
class ReverseLayerF(Function):

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None

In [7]:
class CNNModel(nn.Module):

    def __init__(self, class_names_len, pretrained_weights=''):
        super(CNNModel, self).__init__()
        if pretrained_weights!='':
            self.feature = models.resnet50(pretrained=False)
            self.feature.load_state_dict(torch.load(pretrained_weights),strict=False)
        else:
            # self.feature = models.resnet18(pretrained=True)
            self.feature = models.resnet50(pretrained=True)

        self.in_features = self.feature.fc.in_features
        self.feature.fc = nn.Identity()
        # self.feature = nn.Sequential()
        # self.feature.add_module('f_conv1', nn.Conv2d(3, 64, kernel_size=5))
        # self.feature.add_module('f_bn1', nn.BatchNorm2d(64))
        # self.feature.add_module('f_pool1', nn.MaxPool2d(2))
        # self.feature.add_module('f_relu1', nn.ReLU(True))
        # self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))
        # self.feature.add_module('f_bn2', nn.BatchNorm2d(50))
        # self.feature.add_module('f_drop1', nn.Dropout2d())
        # self.feature.add_module('f_pool2', nn.MaxPool2d(2))
        # self.feature.add_module('f_relu2', nn.ReLU(True))

        self.class_classifier = nn.Sequential()
        self.class_classifier.add_module('c_fc1', nn.Linear(self.in_features, 100))
        self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))
        self.class_classifier.add_module('c_relu1', nn.ReLU(True))
        self.class_classifier.add_module('c_drop1', nn.Dropout())
        self.class_classifier.add_module('c_fc2', nn.Linear(100, 100))
        self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(100))
        self.class_classifier.add_module('c_relu2', nn.ReLU(True))
        self.class_classifier.add_module('c_fc3', nn.Linear(100, class_names_len))
        self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))

        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module('d_fc1', nn.Linear(self.in_features, 100))
        self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))
        self.domain_classifier.add_module('d_relu1', nn.ReLU(True))
        self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2))
        self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1))

    def forward(self, input_data, alpha):
        # input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28)
        feature = self.feature(input_data)
        feature = feature.view(-1, self.in_features)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.class_classifier(feature)
        domain_output = self.domain_classifier(reverse_feature)

        return class_output, domain_output

In [8]:
# load model

my_net = CNNModel(class_names_len=len(cataracts_class_names),pretrained_weights='/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/Dann_barlow/pretrained_weights/resnet50.pth')

# setup optimizer

optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=0.9)

loss_class = torch.nn.NLLLoss()
loss_domain = torch.nn.NLLLoss()

my_net = my_net.to(device)
loss_class = loss_class.to(device)
loss_domain = loss_domain.to(device)

for p in my_net.parameters():
    p.requires_grad = True




In [9]:
# dataloader_source = cataracts_dataloaders
# dataloader_target = d99_balanced_dataloaders
dataloader_target = cataracts_dataloaders
dataloader_source = d99_balanced_dataloaders

In [10]:
def mapper_d99_cat(labels_d99, preds_d99):
    labels_99_arr = labels_d99.cpu().numpy()
    preds_99_arr = preds_d99.cpu().numpy()
    #the keys of this are the labels in the dataset 99 and the values are the corresponding labels in the cataracts dataset
    mapper_dict = {
        0:12,
        1:0,
        2:1,
        3:11,
        4:3,
        5:2,
        6:5,
        7:10,
        8:6,
        9:8,
        10:4,
        11:-1,
        12:9,
        13:-1,
        14:-1,
        15:-1,
        16:-1
    } 

    converted_labels = np.array([mapper_dict[l] for l in labels_99_arr])
    valid_idxs = np.where(converted_labels!=-1)[0]
    valid_labels = converted_labels[valid_idxs]
    valid_preds = preds_99_arr[valid_idxs]
    return torch.Tensor(valid_preds), torch.Tensor(valid_labels)

In [11]:
def test(model_path, dataloader, len_classnames, use_dict=False):
    my_net = CNNModel(len_classnames)
    my_net.load_state_dict(torch.load(model_path))
    my_net = my_net.eval()

    my_net = my_net.to(device)

    len_dataloader = len(dataloader)
    data_target_iter = iter(dataloader)

    i = 0
    n_total = 0
    n_correct = 0

    while i < len_dataloader:

        # test model using target data
        data_target = data_target_iter.next()
        t_img, t_label = data_target

        batch_size = len(t_label)

        t_img = t_img.to(device)
        t_label = t_label.to(device)

        class_output, _ = my_net(input_data=t_img, alpha=0)
        pred = class_output.data.max(1, keepdim=True)[1]

        if use_dict:
            valid_pred, valid_label = mapper_d99_cat(t_label, pred)
        else:
            valid_pred, valid_label = t_label, pred
            
        n_correct += valid_pred.eq(valid_label.data.view_as(valid_pred)).cpu().sum()
        # print(valid_label.size(dim=0))
        n_total += int(valid_label.size(dim=0))

        i += 1

    accu = n_correct.data.numpy() * 1.0 / n_total

    return accu

In [12]:
#load any existing weights
# load_path = '/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/Dann_barlow/pretrained_weights/resnet50.pth'
# load_path = '/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/Dann_barlow/saved_models/src_cata_tgt_d99_model_epoch_current_train_pretrain_bt.pth'
# my_net.load_state_dict(torch.load(load_path))

In [13]:
# training
save_path = './src_d99_tgt_cata_model_simpleDann_current.pth'
best_accu_t = 0.0
for epoch in range(n_epoch):

    len_dataloader = min(len(dataloader_source['train']), len(dataloader_target['train']))
    data_source_iter = iter(dataloader_source['train'])
    data_target_iter = iter(dataloader_target['train'])

    for i in range(len_dataloader):

        p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader
        # alpha = 2. / (1. + np.exp(-10 * p)) - 1
        alpha = 0.5
        # print(alpha)

        # training model using source data
        data_source = data_source_iter.next()
        s_img, s_label = data_source

        my_net.zero_grad()
        batch_size = len(s_label)

        domain_label = torch.zeros(batch_size).long()


        s_img = s_img.to(device)
        s_label = s_label.to(device)
        domain_label = domain_label.to(device)


        class_output, domain_output = my_net(input_data=s_img, alpha=alpha)
        err_s_label = loss_class(class_output, s_label)
        err_s_domain = loss_domain(domain_output, domain_label)

        # training model using target data
        data_target = data_target_iter.next()
        t_img, _ = data_target

        batch_size = len(t_img)

        domain_label = torch.ones(batch_size).long()

        t_img = t_img.to(device)
        domain_label = domain_label.to(device)

        _, domain_output = my_net(input_data=t_img, alpha=alpha)
        err_t_domain = loss_domain(domain_output, domain_label)
        err = err_t_domain + err_s_domain + err_s_label
        err.backward()
        optimizer.step()

        sys.stdout.write('\r epoch: %d, [iter: %d / all %d], err_s_label: %f, err_s_domain: %f, err_t_domain: %f' \
              % (epoch, i + 1, len_dataloader, err_s_label.data.cpu().numpy(),
                 err_s_domain.data.cpu().numpy(), err_t_domain.data.cpu().item()))
        sys.stdout.flush()

        torch.save(my_net.state_dict(), save_path)

    print('\n')
    accu_s = test(save_path, dataloader_source['val'],len(cataracts_class_names))
    print('Accuracy of the %s dataset: %f' % ('cataracts', accu_s))
    accu_t = test(save_path,dataloader_target['val'],len(cataracts_class_names))
    print('Accuracy of the %s dataset: %f\n' % ('dataset99 balanced', accu_t))
    if accu_t > best_accu_t:
        best_accu_s = accu_s
        best_accu_t = accu_t
        torch.save(my_net.state_dict(), './src_cata_tgt_d99_model_epoch_best.pth')

    print("best source acc: ",best_accu_s)
    print("best target acc: ",best_accu_t)

 epoch: 0, [iter: 277 / all 277], err_s_label: 2.105868, err_s_domain: 0.647139, err_t_domain: 0.665206





Accuracy of the cataracts dataset: 0.309798
Accuracy of the dataset99 balanced dataset: 0.182143

best source acc:  0.3097978227060653
best target acc:  0.18214285714285713
 epoch: 1, [iter: 277 / all 277], err_s_label: 1.630829, err_s_domain: 0.560050, err_t_domain: 0.498019

Accuracy of the cataracts dataset: 0.401555
Accuracy of the dataset99 balanced dataset: 0.238690

best source acc:  0.4015552099533437
best target acc:  0.2386904761904762
 epoch: 2, [iter: 277 / all 277], err_s_label: 1.189171, err_s_domain: 0.441395, err_t_domain: 0.449658

Accuracy of the cataracts dataset: 0.420218
Accuracy of the dataset99 balanced dataset: 0.211310

best source acc:  0.4015552099533437
best target acc:  0.2386904761904762
 epoch: 3, [iter: 277 / all 277], err_s_label: 0.849641, err_s_domain: 0.259574, err_t_domain: 0.146070

Accuracy of the cataracts dataset: 0.429549
Accuracy of the dataset99 balanced dataset: 0.235119

best source acc:  0.4015552099533437
best target acc:  0.2386904761904

KeyboardInterrupt: 