In [23]:
import itertools
import torch
import torch.optim as optim
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.autograd import Function
from torchsummary import summary
import os
import numpy as np
from collections import defaultdict
import copy

In [5]:
class ReverseLayer(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None

In [6]:
def build_encoder():
    model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet34')
    modules = list(model.children())[:-1] + [torch.nn.Flatten()]
    model = nn.Sequential(*modules)
    return model

def build_classifier(input_shape, classes, joint=True):
    domains = 2 if joint else 1
    
    classifier = nn.Sequential(
        nn.Linear(input_shape, 1280),
        nn.ReLU(True),
        nn.Dropout(0.2),
        nn.Linear(1280, 1280),
        nn.ReLU(True),
        nn.Dropout(0.2),
        nn.Linear(1280, domains*classes),
        nn.Softmax()
    )
    return classifier

def build_discriminator(input_shape):
    discriminator = nn.Sequential(
        nn.Linear(input_shape, 1280),
        nn.ReLU(True),
        nn.Linear(1280, 1280),
        nn.ReLU(True),
        nn.Linear(1280, 2),
        nn.Softmax()
    )
    return discriminator

In [7]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, labels=True):
        self.data = data
        self.labels = labels
        
    def __getitem__(self, index):
        x = self.data[index]
        return x if self.labels else x[0]
    
    def __len__(self):
        return len(self.data)

In [187]:
class DataGenerator:
    def __init__(
        self, 
        source_domain,
        target_domain,
        val_split=0.2,
        test_split=0.2,
        input_shape=(224,224),
        target_labels=0.1,
        target_train=False
    ):
        src_train, src_val = self.__prepare_data(
            source_domain, 
            input_shape,
            True, 
            val_split,
            test_split,
            target_labels
        )
              
        train_label, train_nlabel, tar_val, test = self.__prepare_data(
            target_domain,  
            input_shape, 
            False, 
            val_split,
            test_split,
            target_labels
        )
        
        val = torch.utils.data.ConcatDataset([src_val, tar_val])
        
        if target_train:
            train_label_set = train_label
        else:
            train_label_set = torch.utils.data.ConcatDataset([src_train, train_label])

        self.train_label_loader = DataLoader(dataset=train_label_set, batch_size=32, shuffle=True, num_workers=0)
        if target_labels < 1:
            self.train_nlabel_loader = DataLoader(dataset=train_nlabel, batch_size=32, shuffle=True, num_workers=0)
        else:
            self.train_nlabel_loader = None
        self.val_loader = DataLoader(dataset=val, batch_size=64, shuffle=True, num_workers=0)
        self.test_loader = DataLoader(dataset=test, batch_size=64, shuffle=True, num_workers=0)
    
    def __prepare_data(self, folder, input_shape, src=True, val_split=0, test_split=0, target_labels=0.1):
        
        transform = transforms.Compose([
            transforms.Resize(input_shape),
            transforms.ToTensor()
        ])
        
        label = 0 if src else 1
        
        data = torchvision.datasets.ImageFolder(folder, transform=transform)
        data.target_transform = lambda id: torch.Tensor((label, id))
        
        self.classes = data.classes
        
        if src:
            
            if target_labels == 0:
                train, val = torch.utils.data.random_split(
                    data, 
                    [round(len(data)*(1-val_split) - 1e-5), round(len(data)*val_split + 1e-5)]
                ) 
            else:
                train, val = torch.utils.data.random_split(
                    data, 
                    [len(data), 0]
                )  
            
            return train, val
        
        else:
            data, test = torch.utils.data.random_split(
                data, 
                [round(len(data)*(1-test_split) - 1e-5), round(len(data)*test_split + 1e-5)]
            )
        
            train, train_nlabel = torch.utils.data.random_split(
                data, 
                [round(len(data)*target_labels - 1e-5), round(len(data)*(1-target_labels) + 1e-5)]
            ) 
    
            train_label, val = torch.utils.data.random_split(
                train, 
                [round(len(train)*(1-val_split) - 1e-5), round(len(train)*val_split + 1e-5)]
            )
            
            return train_label, Dataset(train_nlabel, False), val, test
        
    def train_data(self):
        return self.train_label_loader, self.train_nlabel_loader
    
    def val_data(self):
        return self.val_loader
    
    def test_data(self):
        return self.test_loader

In [9]:
def KL_Loss(y_pred, classes):
    y_joint = torch.reshape(y_pred, (-1, 2*classes))
    
    y_class = torch.unsqueeze(torch.sum(y_pred, 1), 1)
    y_domain = torch.unsqueeze(torch.sum(y_pred, 2), -1)
    
    
    y_ind_joint = torch.reshape((y_domain * y_class), (-1,2*classes))
    
    return torch.nn.KLDivLoss(log_target=True, reduction="sum")(
        torch.log(y_joint), 
        torch.log(y_ind_joint)
    )

In [None]:
class DANNModel(nn.Module):
    def __init__(
        self, 
        classes=65,
    ):
        super(SingleDomainModel, self).__init__()
        self.classes = classes
        self.encoder = build_encoder()
        self.classifier = build_classifier(512, self.classes)
        self.discriminator = build_discriminator(512)
    
    def forward(self, inputs, rep_loss):
        features = self.encoder(inputs)
        classes = self.classifier(features)
        classes = torch.reshape(classes, (-1, 2, self.classes))
        domains = self.discriminator(ReverseLayer.apply(features, rep_loss))
        return classes, domains
    
    def run_batch(self, imgs, labels=None, training=True, use_KL=True, rep_loss=1, risk_loss=1):
        if training:
            opt = optim.Adam(self.parameters(), lr=1e-4)
#             opt = torch.optim.SGD(m.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
            opt.zero_grad()
            
        kl_loss = 0
        class_pred_loss = 0
        pred_count_joint = 0
        pred_count_class = 0
        
        imgs = imgs.cuda()
        y_pred, domain_pred_label = self(imgs, rep_loss) 
        
        if labels is not None:
            labels = labels.cuda()
            labels = labels.to(dtype=torch.int64)
            joint_labels = labels[:,0] * m.classes + labels[:,1]
            dom_labels = torch.cat((labels[:,0], torch.ones(imgs.shape[0] - labels.shape[0], dtype=torch.long).cuda()), 0)

            y_pred = y_pred[:labels.shape[0]]
            y_joint = torch.reshape(y_pred, (-1,2*m.classes))
            y_class = torch.sum(y_pred, 1)

            if use_KL:
                kl_loss = KL_Loss(y_pred, m.classes) * risk_loss

            class_pred_loss = torch.nn.NLLLoss()(torch.log(y_class), labels[:,1])
            
            pred_count_joint = torch.count_nonzero(torch.argmax(y_joint, 1) == joint_labels)
            pred_count_class = torch.count_nonzero(torch.argmax(y_class, 1) == labels[:,1])
             
        dis_loss = torch.nn.NLLLoss()(torch.log(domain_pred_label), dom_labels)
        dis_count = torch.count_nonzero(torch.argmax(domain_pred_label, 1) == dom_labels)
        
        loss = class_pred_loss + kl_loss + dis_loss

        if training:
            loss.backward()
            opt.step()
        
        return (
            float(loss), 
            float(kl_loss), 
            float(class_pred_loss), 
            float(dis_loss), 
            float(pred_count_joint), 
            float(pred_count_class),
            float(dis_count)
        )

In [164]:
class SingleDomainModel(nn.Module):
    def __init__(
        self, 
        classes=65,
        use_KL=True,
        risk_lambda=1,
        lr=1e-4
    ):
        super(SingleDomainModel, self).__init__()
        self.classes = classes
        self.encoder = build_encoder()
        self.classifier = build_classifier(512, self.classes, use_KL)
        
        self.lr = lr
        
        self.use_KL = use_KL
        self.risk_lambda = risk_lambda
        
    def forward(self, inputs):
        features = self.encoder(inputs)
        classes = self.classifier(features)
        if self.use_KL:
            classes = torch.reshape(classes, (-1, 2, self.classes))
        return classes
    
    def __print_metrics(self, metrics, prefix="", precision=4, space=1):
        for k, v in metrics.items():
            print(prefix + k, round(v, precision), sep=": ", end=' ')
        for i in range(space):
            print()
    
    def __train_encoder(self, imgs, labels):
        encoder_opt = optim.Adam(self.encoder.parameters(), lr=self.lr)
        encoder_opt.zero_grad()
        
        # Feedforward
        y_pred = self(imgs) 
        y_class = torch.sum(y_pred, 1)
        
        # Calculate losses and metrics
        class_pred_loss = torch.nn.NLLLoss()(torch.log(y_class), labels)
        kl_loss = KL_Loss(y_pred, m.classes) * self.risk_lambda
        encoder_loss = class_pred_loss + kl_loss
        
        # Backpropogation
        encoder_loss.backward()
        encoder_opt.step()
        
        return y_class, float(encoder_loss), float(kl_loss)
            
    def __train_classifier(self, imgs, labels):
        classifier_opt = optim.Adam(self.encoder.parameters(), lr=self.lr)
        classifier_opt.zero_grad()
        
        # Feedforward
        y_pred = self(imgs) 
        y_joint = torch.reshape(y_pred, (-1,2*m.classes))          

        # Calculate losses and metrics
        joint_pred_loss = torch.nn.NLLLoss()(torch.log(y_joint), labels)
        classifier_loss = joint_pred_loss
        
        # Backpropogation
        classifier_loss.backward()
        classifier_opt.step()
        
        return y_pred, float(classifier_loss)
    
    def __full_training(self, imgs, labels):
        opt = optim.Adam(self.parameters(), lr=self.lr)
        opt.zero_grad()
        
        # Feedforward
        y_pred = self(imgs) 
        
        # Calculate losses and metrics
        class_pred_loss = torch.nn.NLLLoss()(torch.log(y_pred), labels)
        loss = class_pred_loss
        
        # Backpropogation
        loss.backward()
        opt.step()
        
        return y_pred, float(loss)
    
    def __train_batch(self, imgs, labels):
        
        imgs = imgs.cuda()
        labels = labels.cuda()
        labels = labels.to(dtype=torch.int64)
        class_labels = labels[:,1]
        joint_labels = labels[:,0] * m.classes + labels[:,1]
        
        if self.use_KL:
            # Two step training of encoder and classifier with KL-loss
            _, encoder_loss, kl_loss = self.__train_encoder(imgs, class_labels)
            y_pred, classifier_loss = self.__train_classifier(imgs, joint_labels)
            
            y_class = torch.sum(y_pred, 1)
            y_joint = torch.reshape(y_pred, (-1,2*m.classes))
            
            pred_class_count = torch.count_nonzero(torch.argmax(y_class, 1) == class_labels)
            pred_joint_count = torch.count_nonzero(torch.argmax(y_joint, 1) == joint_labels)
            
            loss_metrics = {
                "loss": encoder_loss + classifier_loss,
                "encoder_loss": encoder_loss,
                "kl_loss": kl_loss,
                "classifier_loss": classifier_loss
            }
            
            acc_metrics = {
                "joint_acc": float(pred_joint_count),
                "class_acc": float(pred_class_count)
            }
            
        else:
            y_pred, loss = self.__full_training(imgs, class_labels)
            
            pred_class_count = torch.count_nonzero(torch.argmax(y_pred, 1) == class_labels)
            
            loss_metrics = {
                "loss": loss
            }
            
            acc_metrics = {
                "class_acc": float(pred_class_count)
            }

        return loss_metrics, acc_metrics
    
    def train(self, epochs, train_loader, val_loader=None, patience=None):
        
        train_size = len(train_loader.dataset)
        min_val_loss = np.inf
        best_model = None
        
        current_patience = patience
        
        # Perform training over number of epochs
        for i in range(epochs):
            print(f"Epoch {i+1}")
            
            loss_metrics = defaultdict(float)
            acc_metrics = defaultdict(float)
            
            for imgs, labels in train_loader:
                batch_loss, batch_acc = self.__train_batch(imgs, labels)

                for k, v in batch_loss.items():
                    loss_metrics[k] += v / train_size
                for k, v in batch_acc.items():
                    acc_metrics[k] += v / train_size
            
            self.__print_metrics(loss_metrics, precision=6)
            self.__print_metrics(acc_metrics, precision=4, space=2)
            
            # Run on validation set if provided
            if val_loader is not None:
                val_size = len(val_loader.dataset)
                loss_metrics, acc_metrics = self.evaluate(val_loader, val=True)
                
                # Early stopping
                if loss_metrics["loss"] < min_val_loss:
                    min_val_loss = loss_metrics["loss"]
                    best_model = copy.deepcopy(self.state_dict())
                    current_patience = patience
                else:
                    if current_patience is not None:
                        current_patience -= 1
                        if current_patience <= 0:
                            break
                        
        self.load_state_dict(best_model)
    
    def evaluate(self, loader, val=False):
        
        data_size = len(loader.dataset)
        loss_metrics = defaultdict(float)
        acc_metrics = defaultdict(float)
        
        for imgs, labels in loader:
            batch_loss, batch_acc = self.__evaluate_batch(imgs, labels)
            
            for k, v in batch_loss.items():
                loss_metrics[k] += v / data_size
            for k, v in batch_acc.items():
                acc_metrics[k] += v / data_size
                
        prefix = "val_" if val else ""
        
        self.__print_metrics(loss_metrics, prefix=prefix, precision=6)
        self.__print_metrics(acc_metrics, prefix=prefix, precision=4, space=2)
                
        return loss_metrics, acc_metrics
            
    def __evaluate_batch(self, imgs, labels):
        imgs = imgs.cuda()
        labels = labels.cuda()
        labels = labels.to(dtype=torch.int64)
        class_labels = labels[:,1]
        joint_labels = labels[:,0] * m.classes + labels[:,1]
        
        # Feedforward
        y_pred = self(imgs)
        
        if self.use_KL:           
            y_class = torch.sum(y_pred, 1)
            y_joint = torch.reshape(y_pred, (-1,2*m.classes))
            
            class_pred_loss = float(torch.nn.NLLLoss()(torch.log(y_class), class_labels))
            joint_pred_loss = float(torch.nn.NLLLoss()(torch.log(y_joint), joint_labels))
            kl_loss = float(KL_Loss(y_pred, m.classes) * self.risk_lambda)
            
            encoder_loss = class_pred_loss + kl_loss
            
            pred_class_count = torch.count_nonzero(torch.argmax(y_class, 1) == class_labels)
            
            loss_metrics = {
                "loss": encoder_loss,
                "kl_loss": kl_loss,
                "classifier_loss": class_pred_loss
            }
            
            acc_metrics = {
                "class_acc": float(pred_class_count)
            }
            
        else:          
            pred_class_count = torch.count_nonzero(torch.argmax(y_pred, 1) == class_labels)
            loss = float(torch.nn.NLLLoss()(torch.log(y_pred), class_labels))
            
            loss_metrics = {
                "loss": loss
            }
            
            acc_metrics = {
                "class_acc": float(pred_class_count)
            }

        return loss_metrics, acc_metrics

In [196]:
data_folder = "../Datasets/Experiment"
src_folder = os.path.join(data_folder, "Real World")
target_folder = os.path.join(data_folder, "Product")

d = DataGenerator(
    source_domain=src_folder,
    target_domain=target_folder,
    val_split=0.2,
    test_split=0.2,
    input_shape=(224,224),
    target_labels=0.01,
    target_train=False
)

train_label, train_nlabel = d.train_data()
val = d.val_data()
test = d.test_data()

ValueError: num_samples should be a positive integer value, but got num_samples=0

In [192]:
a = iter(val).next()
a[1]

tensor([[1., 2.],
        [1., 0.],
        [1., 0.],
        [1., 2.],
        [1., 0.],
        [1., 2.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 2.],
        [1., 1.],
        [1., 3.],
        [1., 1.],
        [1., 2.],
        [1., 1.],
        [1., 0.],
        [1., 2.],
        [1., 1.],
        [1., 1.],
        [1., 2.],
        [1., 3.],
        [1., 3.],
        [1., 1.],
        [1., 3.],
        [1., 0.],
        [1., 1.],
        [1., 2.],
        [1., 3.],
        [1., 2.],
        [1., 0.],
        [1., 1.],
        [1., 1.],
        [1., 0.],
        [1., 1.],
        [1., 2.],
        [1., 1.],
        [1., 2.],
        [1., 1.],
        [1., 0.],
        [1., 1.],
        [1., 3.],
        [1., 0.],
        [1., 2.],
        [1., 1.]])

In [193]:
m = SingleDomainModel(classes=len(d.classes), use_KL=False, risk_lambda=10, lr=1e-4)
_ = m.cuda()

Using cache found in /home/rhc37/.cache/torch/hub/pytorch_vision_v0.6.0


In [194]:
m.train(100, train_label, val, patience=5)

Epoch 1
loss: 0.044896 
class_acc: 0.2762 

val_loss: 0.026791 
val_class_acc: 0.5111 

Epoch 2
loss: 0.039572 
class_acc: 0.453 

val_loss: 0.023938 
val_class_acc: 0.4444 

Epoch 3
loss: 0.033005 
class_acc: 0.5912 

val_loss: 0.020886 
val_class_acc: 0.6444 

Epoch 4
loss: 0.022824 
class_acc: 0.7956 

val_loss: 0.016919 
val_class_acc: 0.7556 

Epoch 5
loss: 0.013666 
class_acc: 0.9006 

val_loss: 0.01712 
val_class_acc: 0.6889 

Epoch 6
loss: 0.008326 
class_acc: 0.9337 

val_loss: 0.010537 
val_class_acc: 0.8444 

Epoch 7
loss: 0.007756 
class_acc: 0.9392 

val_loss: 0.011855 
val_class_acc: 0.8222 

Epoch 8
loss: 0.003516 
class_acc: 0.9724 

val_loss: 0.012149 
val_class_acc: 0.8222 

Epoch 9
loss: 0.005581 
class_acc: 0.9337 

val_loss: 0.00993 
val_class_acc: 0.8667 

Epoch 10
loss: 0.002833 
class_acc: 0.9669 

val_loss: 0.018973 
val_class_acc: 0.8 

Epoch 11
loss: 0.002869 
class_acc: 0.989 

val_loss: 0.028355 
val_class_acc: 0.7556 

Epoch 12
loss: 0.003436 
class_acc: 0

In [195]:
m.evaluate(test)

loss: 0.011627 
class_acc: 0.8246 



(defaultdict(float, {'loss': 0.01162700799473545}),
 defaultdict(float, {'class_acc': 0.8245614035087719}))

In [12]:
epochs = 15

In [209]:
use_KL=True

for p in m.parameters():
    p.requires_grad = True

for i in range(epochs):
    print(f"Epoch {i+1}")
    
    loss, kl_loss, pred_loss, dis_loss, nlab_total = (0,0,0,0,0) 
    pred_count_joint, pred_count_class, dis_count, lab_total = (0,0,0,0)
    
    for batch_label, batch_nlabel in itertools.zip_longest(train_label, train_nlabel):
        
        if batch_label is not None:
            labels = batch_label[1]
            
            if batch_nlabel is not None:
                imgs = torch.cat((batch_label[0], batch_nlabel), 0)
            else:
                imgs = batch_label[0]
        else:
            imgs = batch_nlabel
            labels = None
                
        batch_results = m.run_batch(imgs, labels, use_KL=use_KL, rep_loss=rep_loss, risk_loss=risk_loss)
        loss += batch_results[0]
        kl_loss += batch_results[1]
        pred_loss += batch_results[2]
        dis_loss += batch_results[3]
        pred_count_joint += batch_results[4]
        pred_count_class += batch_results[5]
        dis_count += batch_results[6]
        
        if batch_label is not None:
            lab_total += batch_label[1].shape[0]
        if batch_nlabel is not None:
            nlab_total += batch_nlabel.shape[0]
    total = lab_total + nlab_total
        
    print(f"loss={loss/total}, "
          f"kl_loss={kl_loss/total}, " 
          f"pred_loss={pred_loss/total}, "
          f"dis_loss={dis_loss/total}, "
          f"joint_acc={pred_count_joint/lab_total}, "
          f"class_acc={pred_count_class/lab_total}, "
          f"dis_acc={dis_count/total}"
         )
    print()
    
    loss, kl_loss, pred_loss, dis_loss, total = (0,0,0,0,0) 
    pred_count_joint, pred_count_class, dis_count = (0,0,0)
    
    for data in val:    
        imgs, labels = data   
        batch_results = m.run_batch(imgs, labels, training=False, use_KL=use_KL, rep_loss=rep_loss, risk_loss=risk_loss)
        loss += batch_results[0]
        kl_loss += batch_results[1]
        pred_loss += batch_results[2]
        dis_loss += batch_results[3]
        pred_count_joint += batch_results[4]
        pred_count_class += batch_results[5]
        dis_count += batch_results[6]
        total += data[1].shape[0]  
            
    print(f"val_loss={loss/total}, "
          f"val_kl_loss={kl_loss/total}, " 
          f"val_pred_loss={pred_loss/total}, "
          f"val_dis_loss={dis_loss/total}, "
          f"val_joint_acc={pred_count_joint/total}, "
          f"val_class_acc={pred_count_class/total}, "
          f"val_dis_acc={dis_count/total}"
         )
    print()

    
    loss, kl_loss, pred_loss, dis_loss, total = (0,0,0,0,0) 
    pred_count_joint, pred_count_class, dis_count = (0,0,0)
    for data in test:
        imgs, labels = data
        batch_results = m.run_batch(imgs, labels, training=False, use_KL=use_KL, rep_loss=rep_loss, risk_loss=risk_loss)
        loss += batch_results[0]
        kl_loss += batch_results[1]
        pred_loss += batch_results[2]
        dis_loss += batch_results[3]
        pred_count_joint += batch_results[4]
        pred_count_class += batch_results[5]
        dis_count += batch_results[6]
        total += data[1].shape[0]     
        
    print(f"test_loss={loss/total}, "
          f"test_kl_loss={kl_loss/total}, " 
          f"test_pred_loss={pred_loss/total}, "
          f"test_dis_loss={dis_loss/total}, "
          f"test_joint_acc={pred_count_joint/total}, "
          f"test_class_acc={pred_count_class/total}, "
          f"test_dis_acc={dis_count/total}"
         )
    print()

Epoch 1
loss=0.0448097315912823, kl_loss=0.0031949610505652474, pred_loss=0.029040276423299986, dis_loss=0.012574494669311926, joint_acc=0.21428571428571427, class_acc=0.37714285714285717, dis_acc=0.5984405458089669

val_loss=0.08116570048862033, val_kl_loss=0.0034040855036841498, val_pred_loss=0.030469600359598795, val_dis_loss=0.04729201528761122, val_joint_acc=0.044444444444444446, val_class_acc=0.28888888888888886, val_dis_acc=0.0

test_loss=0.0665176458526076, test_kl_loss=0.004598480044749745, test_pred_loss=0.02362041515216493, test_dis_loss=0.03829875326993173, test_joint_acc=0.03508771929824561, test_class_acc=0.40350877192982454, test_dis_acc=0.0

Epoch 2
loss=0.046705354026883666, kl_loss=0.002489541328673707, pred_loss=0.028486819527534946, dis_loss=0.015728992328309176, joint_acc=0.22285714285714286, class_acc=0.3657142857142857, dis_acc=0.5458089668615984

val_loss=0.0897398206922743, val_kl_loss=0.005053795377413432, val_pred_loss=0.02954330179426405, val_dis_loss=0.0551

loss=0.024398054237718934, kl_loss=0.00020799190613856905, pred_loss=0.005284243922310266, dis_loss=0.018905818375230532, joint_acc=0.8828571428571429, class_acc=0.9314285714285714, dis_acc=0.6257309941520468

val_loss=0.09120824601915148, val_kl_loss=0.0002967918084727393, val_pred_loss=0.037303956349690755, val_dis_loss=0.05360750092400445, val_joint_acc=0.0, val_class_acc=0.6222222222222222, val_dis_acc=0.0

test_loss=0.07948094083551775, test_kl_loss=0.0002799253061152341, test_pred_loss=0.03591607746325041, test_dis_loss=0.043284934863709566, test_joint_acc=0.0, test_class_acc=0.5614035087719298, test_dis_acc=0.0

Epoch 15
loss=0.02136706928295931, kl_loss=8.285282108003227e-05, pred_loss=0.004353846802755639, dis_loss=0.01693036966388918, joint_acc=0.8857142857142857, class_acc=0.9342857142857143, dis_acc=0.6042884990253411

val_loss=0.08339019351535373, val_kl_loss=0.00023287754091951582, val_pred_loss=0.03308609591590034, val_dis_loss=0.050071223576863604, val_joint_acc=0.0, va

In [38]:
class DANN(nn.Module):
    def __init__(
        self, 
        classes=65
    ):
        super(DANN, self).__init__()
        self.classes = classes
        self.encoder = build_encoder()
        self.classifier = build_classifier(512, self.classes, False)
        self.discriminator = build_discriminator(512)
    
    def forward(self, inputs, rep_loss=1):
        features = self.encoder(inputs)
        classes = self.classifier(features)
        domains = self.discriminator(ReverseLayer.apply(features, rep_loss))
        return classes, domains
    
    def run_batch(self, imgs, labels=None, training=True, rep_loss=1):
        if training:
            opt = optim.Adam(self.parameters(), lr=1e-4)
            opt.zero_grad()
            
        class_pred_loss = 0
        pred_count_class = 0
        
        imgs = imgs.cuda()
        y_pred, domain_pred_label = self(imgs) 
        
        if labels is not None:
            labels = labels.cuda()
            labels = labels.to(dtype=torch.int64)
            class_labels = labels[:,1]
            domain_labels = torch.cat((labels[:,0], torch.ones(imgs.shape[0] - labels.shape[0], dtype=torch.long).cuda()), 0)

            y_pred = y_pred[:labels.shape[0]]

            class_pred_loss = torch.nn.NLLLoss()(torch.log(y_pred), class_labels)
            pred_count_class = int(torch.count_nonzero(torch.argmax(y_pred, 1) == class_labels))
             
        dis_loss = torch.nn.NLLLoss()(torch.log(domain_pred_label), domain_labels) * rep_loss
        dis_count = int(torch.count_nonzero(torch.argmax(domain_pred_label, 1) == domain_labels))
        
        loss = class_pred_loss + dis_loss

        if training:
            loss.backward()
            opt.step()
        
        return (
            float(loss),  
            float(class_pred_loss), 
            float(dis_loss), 
            float(pred_count_class),
            float(dis_count)
        )

In [39]:
dann = DANN(classes=len(d.classes))
_ = dann.cuda()

Using cache found in /home/rhc37/.cache/torch/hub/pytorch_vision_v0.6.0


In [40]:
# rep_loss=1

for p in dann.parameters():
    p.requires_grad = True    

for epoch in range(epochs):
    
    print(f"Epoch {epoch+1}")
    
    loss, pred_loss, dis_loss, nlab_total = (0,0,0,0) 
    pred_count_class, dis_count, lab_total = (0,0,0)
    len_dataloader = max(len(train_label), len(train_nlabel))
    
    for i, (batch_label, batch_nlabel) in enumerate(itertools.zip_longest(train_label, train_nlabel)):
        
        p = float(i + epoch * len_dataloader) / epochs / len_dataloader
        rep_loss=2. / (1. + np.exp(-10 * p)) - 1
        
        if batch_label is not None:
            labels = batch_label[1]
            
            if batch_nlabel is not None:
                imgs = torch.cat((batch_label[0], batch_nlabel), 0)
            else:
                imgs = batch_label[0]
        else:
            imgs = batch_nlabel
            labels = None
                
        batch_results = dann.run_batch(imgs, labels, rep_loss=rep_loss)
        loss += batch_results[0]
        pred_loss += batch_results[1]
        dis_loss += batch_results[2]
        pred_count_class += batch_results[3]
        dis_count += batch_results[4]
        
        if batch_label is not None:
            lab_total += batch_label[1].shape[0]
        if batch_nlabel is not None:
            nlab_total += batch_nlabel.shape[0]
    total = lab_total + nlab_total
        
    print(f"loss={loss/total}, "
          f"pred_loss={pred_loss/total}, "
          f"dis_loss={dis_loss/total}, "
          f"class_acc={pred_count_class/lab_total}, "
          f"dis_acc={dis_count/total}"
         )
    print()
    
    loss, pred_loss, dis_loss, total = (0,0,0,0) 
    pred_count_class, dis_count = (0,0)
    
    for data in val:    
        imgs, labels = data   
        batch_results = dann.run_batch(imgs, labels, training=False, rep_loss=rep_loss)
        loss += batch_results[0]
        pred_loss += batch_results[1]
        dis_loss += batch_results[2]
        pred_count_class += batch_results[3]
        dis_count += batch_results[4]
        total += data[1].shape[0]  
            
    print(f"val_loss={loss/total}, "
          f"val_pred_loss={pred_loss/total}, "
          f"val_dis_loss={dis_loss/total}, "
          f"val_class_acc={pred_count_class/total}, "
          f"val_dis_acc={dis_count/total}"
         )
    print()

    
    loss, pred_loss, dis_loss, total = (0,0,0,0) 
    pred_count_class, dis_count = (0,0)
    for data in test:
        imgs, labels = data
        batch_results = dann.run_batch(imgs, labels, training=False, rep_loss=rep_loss)
        loss += batch_results[0]
        pred_loss += batch_results[1]
        dis_loss += batch_results[2]
        pred_count_class += batch_results[3]
        dis_count += batch_results[4]
        total += data[1].shape[0]       
        
    print(f"test_loss={loss/total}, "
          f"test_pred_loss={pred_loss/total}, "
          f"test_dis_loss={dis_loss/total}, "
          f"test_class_acc={pred_count_class/total}, "
          f"test_dis_acc={dis_count/total}"
         )
    print()

Epoch 1
loss=0.029174270220899676, pred_loss=0.028099820627803692, dis_loss=0.00107444961125042, class_acc=0.36826347305389223, dis_acc=0.6510721247563352

val_loss=0.05201121436225043, val_pred_loss=0.02787361145019531, val_dis_loss=0.02413760291205512, val_class_acc=0.4222222222222222, val_dis_acc=0.0

test_loss=0.04142024642542789, test_pred_loss=0.02256692919814796, test_dis_loss=0.018853317227279932, test_class_acc=0.45614035087719296, test_dis_acc=0.0

Epoch 2
loss=0.03042834834513376, pred_loss=0.02361818038464522, dis_loss=0.006810167938703217, class_acc=0.5029940119760479, dis_acc=0.631578947368421

val_loss=0.07134602864583334, val_pred_loss=0.026157898373074, val_dis_loss=0.04518813027275933, val_class_acc=0.4888888888888889, val_dis_acc=0.0

test_loss=0.05892959812231231, test_pred_loss=0.023572536936977452, test_dis_loss=0.03535706118533486, test_class_acc=0.47368421052631576, test_dis_acc=0.0

Epoch 3
loss=0.030507138017092997, pred_loss=0.019705610317096375, dis_loss=0.0

In [81]:
b.shape

torch.Size([64, 2])

In [82]:
labels.shape

torch.Size([32, 2])

In [44]:
loss, kl_loss, pred_loss, dis_loss, total = (0,0,0,0,0) 
pred_count_joint, pred_count_class, dis_count = (0,0,0)
for data in test:
    batch_results = m.run_batch(data, training=False, use_KL=use_KL, rep_loss=rep_loss)
    loss += batch_results[0]
    kl_loss += batch_results[1]
    pred_loss += batch_results[2]
    dis_loss += batch_results[3]
    pred_count_joint += batch_results[4]
    pred_count_class += batch_results[5]
    dis_count += batch_results[6]
    total += data[1].shape[0]

print(f"test_loss={loss/total}, "
      f"test_kl_loss={kl_loss/total}, " 
      f"test_pred_loss={pred_loss/total}, "
      f"test_dis_loss={dis_loss/total}, "
      f"test_joint_acc={pred_count_joint/total}, "
      f"test_class_acc={pred_count_class/total}, "
      f"test_dis_acc={dis_count/total}"
     )

test_loss=0.04912133802447403, test_kl_loss=0.0032100612134264225, test_pred_loss=0.023766850170336272, test_dis_loss=0.022144424287896407, test_joint_acc=0.17543859649122806, test_class_acc=0.24561403508771928, test_dis_acc=0.0


In [45]:
for data in test:
    imgs, labels = data
    break

In [47]:
for batch_label, batch_nlabel in itertools.zip_longest(train_label, train_nlabel):
    imgs_label , labels = batch_label
    imgs_nlabel = batch_nlabel
    break

In [51]:
_, domain_pred = m(imgs_nlabel.cuda())
domain_pred

tensor([0.2744, 0.2950, 0.2825, 0.3014, 0.2855, 0.2858, 0.2895, 0.2864, 0.2892,
        0.2879, 0.2831, 0.2841, 0.2889, 0.2999, 0.2810, 0.2836, 0.2994, 0.2932,
        0.2896, 0.2690, 0.2878, 0.2886, 0.2119, 0.2940, 0.2667, 0.2445, 0.2914,
        0.2854, 0.2884, 0.2849, 0.2927, 0.2858], device='cuda:0',
       grad_fn=<SqueezeBackward0>)

In [47]:
torch.argmax(torch.sum(y_pred, 1), 1).shape

torch.Size([57])

In [46]:
labels[:,1].shape

torch.Size([57])

In [49]:
torch.sum(labels, axis=1)

tensor([2., 2., 2., 3., 4., 4., 2., 2., 1., 3., 2., 4., 4., 2., 2., 2., 2., 2.,
        1., 3., 2., 1., 3., 4., 1., 3., 4., 1., 4., 3., 2., 1., 3., 4., 4., 4.,
        1., 3., 4., 3., 2., 3., 2., 2., 1., 3., 3., 3., 4., 1., 1., 1., 1., 4.,
        4., 2., 4.])

In [52]:
torch.count_nonzero(torch.argmax(torch.sum(y_pred, 1), 1) == labels[:,1].cuda()) / labels.shape[0]

tensor(0.4386, device='cuda:0')

In [237]:
y_pred = pred[:1]
y_joint = torch.reshape(y_pred, (1, 2*4))
    
y_class = torch.unsqueeze(torch.sum(y_pred, 1), 1)
y_domain = torch.unsqueeze(torch.sum(y_pred, 2), -1)
    
    
y_ind_joint = torch.reshape((y_domain * y_class), (1,2*4))

In [238]:
y_pred

tensor([[[0.2102, 0.3443, 0.0579, 0.0241],
         [0.0946, 0.2194, 0.0309, 0.0186]]], device='cuda:0',
       grad_fn=<SliceBackward>)

In [239]:
y_joint

tensor([[0.2102, 0.3443, 0.0579, 0.0241, 0.0946, 0.2194, 0.0309, 0.0186]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [240]:
y_class

tensor([[[0.3048, 0.5637, 0.0888, 0.0427]]], device='cuda:0',
       grad_fn=<UnsqueezeBackward0>)

In [241]:
y_domain

tensor([[[0.6365],
         [0.3635]]], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [242]:
y_ind_joint

tensor([[0.1940, 0.3588, 0.0565, 0.0272, 0.1108, 0.2049, 0.0323, 0.0155]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [219]:
torch.nn.KLDivLoss(log_target=True, reduction="sum")(
    torch.log(y_joint), 
    torch.log(y_ind_joint)
)

tensor(inf)

In [257]:
def kl_divergence(p, q):
    p = p.reshape(-1)
    q = q.reshape(-1)
    print(p, q)
    return sum(p[i] * torch.log2(p[i]/q[i]) for i in range(len(p)))

In [258]:
kl_divergence(y_joint, y_ind_joint)

tensor([0.2102, 0.3443, 0.0579, 0.0241, 0.0946, 0.2194, 0.0309, 0.0186],
       device='cuda:0', grad_fn=<ViewBackward>) tensor([0.1940, 0.3588, 0.0565, 0.0272, 0.1108, 0.2049, 0.0323, 0.0155],
       device='cuda:0', grad_fn=<ViewBackward>)


tensor(0.0046, device='cuda:0', grad_fn=<AddBackward0>)