<a href="https://colab.research.google.com/github/HiroTakeda/Notes/blob/main/Note05_DomainAdaptation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Domain Adaptation

+ Related topics: __Regression__, __Logistic Regression__, __Neural Network__, __Generative Adversarial Net__

Domain adaptation with neural networks was presented by Ganin *et al.* in 2016 [[paper](https://www.jmlr.org/papers/volume17/15-239/15-239.pdf)]. It is an important technique that applies to a wide variety of problems. This short note briefly explains the proposed method.

The performance of a parametric model depends strongly on its training dataset. When addressing a real-world problem, we usually begin by collecting training data. Because labeling data is time-consuming, it is tempting to reuse existing labeled datasets or to synthesize labeled data with computer graphics. However, a model trained on such data may not work well on the unlabeled data we actually care about, because the two datasets come from different domains, i.e., their input distributions differ. For example, the labeled images may all be grayscale while the unlabeled images are in color. The neural-network domain adaptation technique proposed by Ganin *et al.* lets us train a model for unlabeled data using labeled data from a different domain. The block diagram below shows the basic domain adaptation network architecture, in which the source data are labeled and the target data are unlabeled.

![picture](https://drive.google.com/uc?id=1NZjcc6iGaSLxLBs1zPpEVHAk7cojn5FG)

The objective of domain adaptation may be stated as:

> Given a labeled (source) dataset and an unlabeled (target) dataset whose inputs may look different (e.g., grayscale vs. color), find a label predictor that is valid for both datasets.



## Network Model

The basic domain adaptation model shown above consists of three components: a feature extractor, a label predictor, and a domain classifier. The feature extractor converts an input into a feature vector, the label predictor estimates the class label from the feature vector, and the domain classifier decides whether the feature vector comes from a source (labeled) or a target (unlabeled) input. Without the domain classifier, the network is an ordinary classifier: a feature extractor followed by a label predictor with a softmax output, which acts as a multinomial logistic regression on the learned features.
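
The data flow through the three components can be sketched with toy, randomly initialized linear layers in NumPy. The layer sizes, the `tanh` nonlinearity, and the Python names `F`, `L`, `D` below are illustrative assumptions, not the architecture used in the Example section; the sketch only shows which component consumes which output.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_feat, n_classes = 6, 4, 3  # toy sizes (assumed for illustration)

# randomly initialized parameters of the three components
W_F = rng.normal(size=(n_feat, n_in))
W_L = rng.normal(size=(n_classes, n_feat))
w_D = rng.normal(size=n_feat)

def F(x):
    """Feature extractor: input (n_in,) -> feature vector (n_feat,)."""
    return np.tanh(W_F @ x)

def L(f):
    """Label predictor: feature vector -> class probabilities (softmax)."""
    a = W_L @ f
    e = np.exp(a - a.max())  # subtract the max for numerical stability
    return e / e.sum()

def D(f):
    """Domain classifier: feature vector -> probability that the input is from the source."""
    return 1.0 / (1.0 + np.exp(-(w_D @ f)))

x = rng.normal(size=n_in)
f = F(x)
p_class = L(f)   # used for source and target inputs alike
p_source = D(f)  # only needed during training
print(f.shape, p_class.round(3), round(float(p_source), 3))
```

Dropping `D` leaves `L(F(x))`, which is exactly the classifier used at inference time.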


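The goal is a label predictor that is valid for both source (labeled) and target (unlabeled) data, but the label predictor cannot be trained directly on target data because they have no labels. What we can do instead is match the probability distributions of the source and target features: $p(F(\mathbf{z})) = p(F(\mathbf{x}))$, where $\mathbf{z}$ denotes a source input and $\mathbf{x}$ a target input. If the distributions match, the label predictor sees target features drawn from the same distribution as the source features it was trained on, so it can be expected to work on the unlabeled target data as well (assuming the class of an input depends on its features in the same way in both domains). The idea of matching probability distributions is reminiscent of the __Generative Adversarial Net__ (GAN): the domain classifier plays the role of the GAN discriminator and the feature extractor plays the role of the generator, so the optimization process is similar to that of a GAN.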
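
The discriminator view can be illustrated in one dimension. This is a minimal sketch under two simplifying assumptions made only for illustration: the features are scalars, and the domain classifier is restricted to a single threshold. When the source and target feature distributions differ, the best threshold separates the domains well; when they match, it cannot do much better than chance (0.5).

```python
import numpy as np

def best_threshold_accuracy(f_src, f_tgt):
    """Accuracy of the best single-threshold domain classifier on 1-D features."""
    f = np.concatenate([f_src, f_tgt])
    t = np.concatenate([np.ones(len(f_src)), np.zeros(len(f_tgt))])  # 1 = source, 0 = target
    t_sorted = t[np.argsort(f)]
    n = len(t_sorted)
    # rule "source if f is among the k smallest features", for every k = 0..n:
    # correct = sources among the first k + targets among the remaining n - k
    src_below = np.concatenate([[0.0], np.cumsum(t_sorted)])
    tgt_above = (1 - t_sorted).sum() - np.concatenate([[0.0], np.cumsum(1 - t_sorted)])
    acc = (src_below + tgt_above) / n
    # the mirrored rule ("target if below the threshold") has accuracy 1 - acc
    return float(np.max(np.maximum(acc, 1 - acc)))

rng = np.random.default_rng(0)
f_src = rng.normal(0.0, 1.0, 5000)          # stand-in for F(z): source features
f_tgt_shifted = rng.normal(3.0, 1.0, 5000)  # F(x) before adaptation (toy domain shift)
f_tgt_matched = rng.normal(0.0, 1.0, 5000)  # F(x) whose distribution matches F(z)

acc_shifted = best_threshold_accuracy(f_src, f_tgt_shifted)
acc_matched = best_threshold_accuracy(f_src, f_tgt_matched)
print(acc_shifted)  # roughly 0.93
print(acc_matched)  # close to 0.5
```

A domain classifier that is stuck near chance is the signal that the feature distributions have been matched, at least as far as that classifier can tell.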

## Optimization

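The domain adaptation network consists of three components that must be optimized with different criteria. Ganin *et al.* train all three jointly by inserting a gradient reversal layer between the feature extractor and the domain classifier; this note instead follows a GAN-style alternating scheme, in which each part of the network is updated in turn with stochastic gradient steps, as shown below.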

### Step 1: Label Predictor
![picture](https://drive.google.com/uc?id=1YhSMMO2vFgpCI04quv8pPuZlcDGFZM3A)

First, we optimize the feature extractor $F$ and the label predictor $L$ using the source (labeled) inputs $\mathbf{z}_n$ and their class labels $l_n$. The combined network $L(F(\cdot))$ is a multinomial logistic regression on the learned features, so we train it by maximizing the log-likelihood function:

$$
\displaystyle\max_{L, F} \displaystyle\sum_{n} \ln p(\,l_n\, |\, L(F(\mathbf{z}_n))\,).
$$
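
Maximizing this log-likelihood is equivalent to minimizing the mean negative log-likelihood, which is what `nn.LogSoftmax` followed by `nn.NLLLoss` (default `reduction='mean'`) computes in the Example code. The small NumPy sketch below evaluates the objective for two made-up score vectors; the helper names are hypothetical.

```python
import numpy as np

def log_softmax(a):
    """Row-wise log-softmax of an (N, K) array of scores."""
    a = a - a.max(axis=1, keepdims=True)  # numerical stability
    return a - np.log(np.exp(a).sum(axis=1, keepdims=True))

def log_likelihood(scores, labels):
    """sum_n ln p(l_n | L(F(z_n))), given the pre-softmax scores of L(F(z_n))."""
    logp = log_softmax(scores)
    return float(logp[np.arange(len(labels)), labels].sum())

scores = np.array([[2.0, 1.0, 0.0],   # hypothetical scores for two source samples
                   [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
ll = log_likelihood(scores, labels)
nll_mean = -ll / len(labels)  # the quantity a mean-reduced NLL loss minimizes
print(round(ll, 4))  # -1.5062
```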

### Step 2: Domain Classifier with source data
![picture](https://drive.google.com/uc?id=1Whmju4wH7F-HLgTZle4kHJB2JM-tqKF2)

Next, we optimize the domain classifier $D$ using the source data while keeping $F$ fixed. Because the inputs are source data, the domain labels $t_n$ are $1$ for all $n$. As noted above, the domain classifier plays the role of the GAN discriminator; it is a binary logistic regression on the features. We optimize $D$ by maximizing the log-likelihood function (the gray term, which involves the target data, drops out in this step because $1-t_n=0$):

$$
\displaystyle\max_{D} \displaystyle\sum_{n} t_n \ln p(\,t_n \, |\, D(F(\mathbf{z}_n)) \,) \color{lightgray}{+ \displaystyle\sum_{n} (1-t_n) \ln p(\,t_n\, | \,D(F(\mathbf{x}_n)) \,) },
$$

with $t_n=1$ (source).
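
Writing $D_n$ for the classifier's probability that feature $n$ came from the source, $p(t_n \,|\, D) = D_n^{t_n}(1-D_n)^{1-t_n}$, so the objective reduces to $\sum_n t_n \ln D_n + (1-t_n)\ln(1-D_n)$. A minimal NumPy sketch with made-up classifier outputs (the helper name is hypothetical) shows that with $t_n=1$ only the $\ln D_n$ terms remain:

```python
import numpy as np

def domain_log_likelihood(p_source, t):
    """sum_n t_n ln D_n + (1 - t_n) ln(1 - D_n), where D_n = p(source | feature n)."""
    p_source = np.clip(p_source, 1e-12, 1 - 1e-12)  # avoid log(0)
    return float(np.sum(t * np.log(p_source) + (1 - t) * np.log(1 - p_source)))

# hypothetical outputs of D for three source features
p = np.array([0.9, 0.6, 0.2])
t = np.ones(3)  # Step 2: every domain label is 1 (source)
ll_step2 = domain_log_likelihood(p, t)
print(round(ll_step2, 4))  # -2.2256, i.e. ln 0.9 + ln 0.6 + ln 0.2
```

The third sample, which $D$ currently believes is target, contributes most of the loss, so maximizing the objective pushes $D$ toward calling it source.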

### Step 3: Domain Classifier with target data
![picture](https://drive.google.com/uc?id=1u8Hzmn5Fa4u7zSsXIpyQhXY4J3x1TjS2)

We then optimize the domain classifier $D$, again with $F$ fixed, using the target data. Because the inputs are now target data, the labels $t_n$ are $0$ for all $n$. We use the same log-likelihood function (the gray term, which involves the source data, drops out in this step because $t_n=0$):

$$
\displaystyle\max_{D} \color{lightgray}{\displaystyle\sum_{n} t_n \ln p(\,t_n\, |\, D(F(\mathbf{z}_n)) \,)} \color{black}{+ \displaystyle\sum_{n} (1-t_n) \ln p(\,t_n\, |\, D(F(\mathbf{x}_n)) \,),}
$$

with $t_n=0$ (target).
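
Steps 2 and 3 together can be sketched in NumPy under toy assumptions: the features are scalars produced by a fixed $F$ (here simply Gaussian samples standing in for $F(\mathbf{z}_n)$ and $F(\mathbf{x}_n)$), and $D$ is a 1-D logistic regression trained by full-batch gradient ascent. The Example code does the analogous thing with mini-batches, adding the source and target losses before a single optimizer step for $D$.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(1)
f_src = rng.normal(0.0, 1.0, 500)  # F(z_n): fixed source features (toy)
f_tgt = rng.normal(3.0, 1.0, 500)  # F(x_n): fixed target features (toy)

w, b = 0.0, 0.0  # D(f) = sigmoid(w * f + b) = p(source | f)
lr = 0.1
for _ in range(500):
    p_src = sigmoid(w * f_src + b)  # Step 2: source features, t_n = 1
    p_tgt = sigmoid(w * f_tgt + b)  # Step 3: target features, t_n = 0
    # gradient of the mean log-likelihood is (t_n - D_n) * d(w f + b)/d(param)
    grad_w = np.mean((1 - p_src) * f_src) + np.mean((0 - p_tgt) * f_tgt)
    grad_b = np.mean(1 - p_src) + np.mean(0 - p_tgt)
    w += lr * grad_w  # gradient ASCENT on the log-likelihood
    b += lr * grad_b

mean_src = sigmoid(w * f_src + b).mean()  # should be well above 0.5
mean_tgt = sigmoid(w * f_tgt + b).mean()  # should be well below 0.5
print(round(w, 2), round(b, 2), mean_src, mean_tgt)
```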


### Step 4: Feature extractor with target data
![picture](https://drive.google.com/uc?id=1xs9On1pinlEe0Mfb2nWvM3AEC-8ZWN2p)

Finally, we optimize $F$ with $D$ fixed. We want $F$ to fool the domain classifier $D$ into classifying the target features $F(\mathbf{x})$ as source, so that the probability distributions of $F(\mathbf{x})$ and $F(\mathbf{z})$ match; as a result, the label predictor $L$ becomes valid for target data as well. To do this, we flip the label value $t_n$ to 1 (source) and maximize the following log-likelihood function:

$$
\displaystyle\max_{F} \displaystyle\sum_n t_n \ln p(\, t_n \, | \, D(F(\mathbf{x}_n)) \, ),
$$

with $t_n=1$ (we want $F(\mathbf{x})$ to be classified as source).
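
A toy sketch of Step 4, assuming a one-parameter feature extractor $F(x) = x + c$ and a fixed 1-D logistic domain classifier whose parameters are hand-picked (as if Steps 2 and 3 had already trained it) to call features below 1.5 "source". Gradient ascent on $\sum_n \ln D(F(\mathbf{x}_n))$ with the flipped label moves the target features into the region $D$ labels as source.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(2)
x_tgt = rng.normal(3.0, 1.0, 500)  # raw target inputs (toy)

# D is fixed: p(source | f) = sigmoid(w f + b), boundary at f = 1.5 (hand-picked)
w, b = -3.0, 4.5

c = 0.0  # single parameter of the toy feature extractor F(x) = x + c
lr = 0.1
for _ in range(200):
    p = sigmoid(w * (x_tgt + c) + b)  # D(F(x_n))
    # d/dc of mean ln D(F(x_n)) with the flipped label t_n = 1
    grad_c = np.mean((1 - p) * w)
    c += lr * grad_c

fooled = sigmoid(w * (x_tgt + c) + b).mean()  # mean p(source) for target features
print(round(c, 2), fooled)
```

With $D$ frozen, $F$ simply keeps pushing features deeper into $D$'s source region; this is why the scheme above retrains $D$ (Steps 2 and 3) on every iteration, as in GAN training, so that $F$ is driven toward actually matching the source distribution.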


## Inference

![picture](https://drive.google.com/uc?id=1TU89EsaNG-UfLLoeR67VXlLl49--VAMi)

Once training is done, the domain classifier is no longer needed: the combined model $L(F(\cdot))$ of the feature extractor and the label predictor classifies a new input $\mathbf{x}$.
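
Why alignment pays off at inference time can be shown with a toy example. It assumes two classes with 1-D source inputs centered at -1 and +1, a target domain that is the same data shifted by +3, and a label predictor fixed to the rule "class 1 if the feature is positive". `F_aligned` is a hypothetical hand-written extractor that removes the shift, standing in for what Steps 1 to 4 try to learn.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 1000
labels = rng.integers(0, 2, n)                    # two classes (toy)
z = rng.normal(np.where(labels == 1, 1.0, -1.0))  # source inputs: class means -1 and +1
x = z + 3.0                                       # target inputs: same classes, shifted domain

def L(f):
    """Label predictor learned on source features: class 1 if the feature is positive."""
    return (f > 0).astype(int)

def F_identity(v):
    return v  # no adaptation

def F_aligned(v, shift):
    return v - shift  # hypothetical adapted extractor that removes the domain shift

acc_src = np.mean(L(F_identity(z)) == labels)
acc_tgt_plain = np.mean(L(F_identity(x)) == labels)
acc_tgt_aligned = np.mean(L(F_aligned(x, 3.0)) == labels)
print(acc_src, acc_tgt_plain, acc_tgt_aligned)
```

Without alignment nearly every target input is predicted as class 1 (accuracy near 0.5); with aligned features the same label predictor matches its source accuracy.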

## Example

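The example below applies domain adaptation to digit classification, using the MNIST dataset [[link](http://yann.lecun.com/exdb/mnist/)] as the source (labeled) data and the MNIST-M dataset [[link](https://drive.google.com/drive/folders/0B_tExHiYS-0vR2dNZEU4NGlSSW8?resourcekey=0-Rs-0pTFZmKp_I1HoBkbiug)] as the target (unlabeled) data. MNIST images are replicated to three channels and padded to 32x32 so that both datasets share the same input shape, and the MNIST-M labels are used only to evaluate the label predictor, never for training.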
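
MNIST is distributed in the IDX format: a big-endian header (two zero bytes, a data-type byte, a dimension-count byte, then one 32-bit size per dimension) followed by the raw data. The `MNISTdataset` class in the code below unpacks it one image at a time with `struct`; a vectorized alternative is sketched here, assuming the uncompressed `*-ubyte` files. The helper name `read_idx` is hypothetical, and it returns arrays shaped `(N, 28, 28)` rather than the `(28, 28, N)` layout used by `MNISTdataset`.

```python
import struct
import numpy as np

def read_idx(data: bytes) -> np.ndarray:
    """Parse an uncompressed IDX byte string holding unsigned bytes (type code 0x08)."""
    zero, dtype_code, ndim = struct.unpack('>HBB', data[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError('only unsigned-byte IDX data is supported in this sketch')
    shape = struct.unpack('>%dI' % ndim, data[4:4 + 4 * ndim])
    arr = np.frombuffer(data, dtype=np.uint8, offset=4 + 4 * ndim)
    return arr.reshape(shape)

# usage with a file on disk (the path is an example):
# with open('train-images.idx3-ubyte', 'rb') as fh:
#     images = read_idx(fh.read()).astype(np.float32) / 255.0  # (60000, 28, 28)

# self-check on a tiny synthetic IDX buffer: two 2x3 "images"
header = struct.pack('>HBB3I', 0, 0x08, 3, 2, 2, 3)
demo = read_idx(header + bytes(range(12)))
print(demo.shape)  # (2, 2, 3)
```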



In [None]:
import os
import numpy
import pylab
import struct
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# MNIST dataset
class MNISTdataset(Dataset):
  def __init__(self, path_root, mode='train', transforms=None):
    
    self.transforms = transforms
    
    if mode == 'train':
      path_images = os.path.join(path_root, 'train-images.idx3-ubyte')
      path_labels = os.path.join(path_root, 'train-labels.idx1-ubyte')
      num_images = 60000
    else:
      path_images = os.path.join(path_root, 't10k-images.idx3-ubyte')
      path_labels = os.path.join(path_root, 't10k-labels.idx1-ubyte')
      num_images = 10000
      
    with open(path_images, 'rb') as f:
      self.images = numpy.zeros((28, 28, num_images), dtype=numpy.float32)
      b = f.read(16)
      for n in range(num_images):
        b = f.read(28*28)
        b = struct.unpack('>%dB' % (28*28), b)
        self.images[:,:,n] = numpy.array(b).reshape(28, 28).astype(numpy.float32) / 255.0
  
    with open(path_labels, 'rb') as f:
      b = f.read(8)
      b = f.read(num_images)
      b = struct.unpack('>%dB' % (num_images), b)
      self.labels = numpy.array(b)
  
  def __len__(self):
    return self.labels.shape[0]
  
  def __getitem__(self, index):
    img = self.images[:,:,index].reshape(28, 28, 1)
    img = numpy.dstack((img, img, img)) # (28 x 28 x 1) --> (28 x 28 x 3)
    img = numpy.pad(img, ((2,2), (2,2), (0,0)), 'symmetric') # (28 x 28 x 3) --> (32 x 32 x 3)
    
    lbl = self.labels[index].astype(numpy.int64)
    
    sample = {'image': img, 'label': lbl}
    if self.transforms is not None:
      sample = self.transforms(sample)
      
    return sample

# MNIST-M dataset
class MNISTMdataset(Dataset):
  def __init__(self, path_root, mode='train', transforms=None):
    
    self.transforms = transforms
    
    # MNIST-M dataset
    if mode == 'train':
      path_images = os.path.join(path_root, 'mnist_m_train')
      path_labels = os.path.join(path_root, 'mnist_m_train_labels.txt')
    else:
      path_images = os.path.join(path_root, 'mnist_m_test')
      path_labels = os.path.join(path_root, 'mnist_m_test_labels.txt')
      
    # read the label file once, keeping only the entries whose image file exists
    entries = []
    with open(path_labels, 'rt') as f:
      for line in f:
        parts = line.split()
        if len(parts) == 2 and os.path.isfile(os.path.join(path_images, parts[0])):
          entries.append((parts[0], int(parts[1])))
    num_images = len(entries)
    
    self.images = numpy.zeros((32, 32, 3, num_images), dtype=numpy.float32)
    self.labels = numpy.zeros(num_images, dtype=numpy.int64)  # numpy.int was removed in NumPy 1.24
    for n, (fname, lbl) in enumerate(entries):
      img = pylab.imread(os.path.join(path_images, fname)).astype(numpy.float32)
      if numpy.max(img) > 1.0:  # some image formats are read as 8-bit values in [0, 255]
        img /= 255.0
      self.images[:,:,:,n] = img
      self.labels[n] = lbl
    
  
  def __len__(self):
    return self.labels.shape[0]
  
  def __getitem__(self, index):
    
    img = self.images[:,:,:,index].reshape(32, 32, 3)
    
    lbl = self.labels[index].astype(numpy.int64)
    
    sample = {'image': img, 'label': lbl}
    if self.transforms is not None:
      sample = self.transforms(sample)
      
    return sample
  
class myToTensor(object):
  def __init__(self):
    self.ToTensor = transforms.ToTensor()
    
  def __call__(self, sample):
    sample['image'] = self.ToTensor(sample['image'])
    return sample


# Feature Extractor
class FeatureExtractor(nn.Module):
  def __init__(self, in_channels, out_channels, nf, ngpu):
    super(FeatureExtractor, self).__init__()
    self.ngpu = ngpu
    self.main = nn.Sequential(
      # in_channels x 32 x 32
      nn.Conv2d(in_channels=in_channels, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 16 x 16
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 16 x 16
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 16 x 16
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 16 x 16
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 16 x 16
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 16 x 16
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 16 x 16
      nn.Conv2d(in_channels=nf, out_channels=nf*2, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(nf*2),
      nn.ReLU(inplace=True),
      # nf*2 x 8 x 8
      nn.Conv2d(in_channels=nf*2, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(out_channels),  # must match the output channels of the preceding conv
      nn.ReLU(inplace=True),
      # out_channels x 8 x 8
    )
    
  def forward(self, x):
    return self.main(x)


# Label Predictor
class LabelPredictor(nn.Module):
  def __init__(self, in_channels, out_channels, nf, ngpu):
    super(LabelPredictor, self).__init__()
    self.ngpu = ngpu
    self.main = nn.Sequential(
      # in_channels x 8 x 8
      nn.Conv2d(in_channels=in_channels, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 4 x 4
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 4 x 4
      nn.Conv2d(in_channels=nf, out_channels=out_channels, kernel_size=4, stride=1, padding=0, bias=False),
      nn.LogSoftmax(dim=1)
      # out_channels x 1 x 1
    )
    
  def forward(self, x):
    return self.main(x)


# Domain Classifier
class DomainClassifier(nn.Module):
  def __init__(self, in_channels, nf, ngpu):
    super(DomainClassifier, self).__init__()
    self.ngpu = ngpu
    self.main = nn.Sequential(
      # in_channels x 8 x 8
      nn.Conv2d(in_channels=in_channels, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 4 x 4
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 4 x 4
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 4 x 4
      nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False),
      nn.BatchNorm2d(nf),
      nn.ReLU(inplace=True),
      # nf x 4 x 4
      nn.Conv2d(in_channels=nf, out_channels=2, kernel_size=4, stride=1, padding=0, bias=False),
      nn.LogSoftmax(dim=1),
      # 2 x 1 x 1 (log-probabilities of target = 0 and source = 1)
    )
    
  def forward(self, x):
    return self.main(x)


def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    nn.init.normal_(m.weight.data, 0.0, 0.02)
  elif classname.find('BatchNorm') != -1:
    nn.init.normal_(m.weight.data, 1.0, 0.02)
    nn.init.constant_(m.bias.data, 0)


def saveConfusionMatrix(fname, labels_true, labels_predicted):
  pylab.ion()
  cnf = confusion_matrix(labels_true, labels_predicted, normalize='true')
  cnf = numpy.round(cnf*100, 0)
  cmd = ConfusionMatrixDisplay(cnf)
  cmd.plot()
  pylab.show(block=False)
  pylab.pause(1)
  pylab.savefig(fname)
  pylab.close()
  return


def main():
  
  # dataset locations (adjust to your environment)
  path_mnist = 'F:\\datasets\\mnist'
  path_mnistm = 'F:\\datasets\\mnist-m'
  batch_size = 64
  num_workers = 2
  in_channels_FE = 3
  out_channels_FE = 16
  num_features_FE = 8
  ngpu = 1
  num_classes = 10
  num_features_LP = num_features_FE * 2
  in_channels_LP = out_channels_FE
  in_channels_DC = in_channels_LP
  num_features_DC = num_features_LP
  learning_rate = 1e-4
  num_epochs = 200
  gamma = 0.99
  
  mnist_cnf_fig = 'F:\\MNIST_cnf.jpg'
  mnistm_cnf_fig = 'F:\\MNIST-M_cnf.jpg'
  loss_curve_fig = 'F:\\loss.jpg'
  
  if torch.cuda.is_available() and ngpu > 0:
    device = torch.device('cuda:0')
  else:
    device = torch.device('cpu')
  
  mnist_train_transforms = transforms.Compose([
      myToTensor(),
    ])
  
  mnist_eval_transforms = transforms.Compose([
      myToTensor(),
    ])
  
  mnistm_train_transforms = transforms.Compose([
      myToTensor(),
    ])
  
  mnistm_eval_transforms = transforms.Compose([
      myToTensor(),
    ])
  
  mnist_train_dataset = MNISTdataset(path_mnist, mode='train', transforms=mnist_train_transforms)
  mnist_eval_dataset = MNISTdataset(path_mnist, mode='eval', transforms=mnist_eval_transforms)
  mnistm_train_dataset = MNISTMdataset(path_mnistm, mode='train', transforms=mnistm_train_transforms)
  mnistm_eval_dataset = MNISTMdataset(path_mnistm, mode='eval', transforms=mnistm_eval_transforms)
  
  mnist_train_dataloader = DataLoader(dataset=mnist_train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers)
  
  mnist_eval_dataloader = DataLoader(dataset=mnist_eval_dataset,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=num_workers)
  
  mnistm_train_dataloader = DataLoader(dataset=mnistm_train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers)
  
  mnistm_eval_dataloader = DataLoader(dataset=mnistm_eval_dataset,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=num_workers)
  
  modelFE = FeatureExtractor(in_channels=in_channels_FE, out_channels=out_channels_FE, nf=num_features_FE, ngpu=ngpu)
  modelLP = LabelPredictor(in_channels=in_channels_LP, out_channels=num_classes, nf=num_features_LP, ngpu=ngpu)
  modelDC = DomainClassifier(in_channels=in_channels_DC, nf=num_features_DC, ngpu=ngpu)
  
  modelFE.apply(weights_init)
  modelLP.apply(weights_init)
  modelDC.apply(weights_init)
  
  modelFE.to(device)
  modelLP.to(device)
  modelDC.to(device)
  
  criterionLP = nn.NLLLoss()
  criterionDC = nn.NLLLoss()
  
  optimizerFE = torch.optim.SGD(modelFE.parameters(), lr=learning_rate, weight_decay=1e-8)
  optimizerLP = torch.optim.SGD(modelLP.parameters(), lr=learning_rate, weight_decay=1e-8)
  optimizerDC = torch.optim.SGD(modelDC.parameters(), lr=learning_rate, weight_decay=1e-8)
  
  schedulerFE = torch.optim.lr_scheduler.ExponentialLR(optimizerFE, gamma=gamma)
  schedulerLP = torch.optim.lr_scheduler.ExponentialLR(optimizerLP, gamma=gamma)
  schedulerDC = torch.optim.lr_scheduler.ExponentialLR(optimizerDC, gamma=gamma)
  
  source_label = 1
  target_label = 0
  
  LP_train_losses = []
  LP_eval_losses_1 = []
  LP_eval_losses_2 = []
  DC_train_losses_1 = []
  DC_train_losses_2 = []
  for epoch in range(num_epochs):
    
    # training
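    modelFE.train()
    modelLP.train()
    modelDC.train()
    LP_loss = 0
    DC_loss_1 = 0
    DC_loss_2 = 0
    mnist_train_iter = iter(mnist_train_dataloader)
    for i, mnistm in enumerate(mnistm_train_dataloader, 0):
      
      # target batch (the MNIST-M labels are not used for training; only their count is used)
      image_mnistm = mnistm['image'].to(device)
      label_mnistm = mnistm['label'].to(device)
      
      # source batch; restart the MNIST iterator if it is exhausted before MNIST-M
      try:
        mnist = next(mnist_train_iter)
      except StopIteration:
        mnist_train_iter = iter(mnist_train_dataloader)
        mnist = next(mnist_train_iter)
      image_mnist = mnist['image'].to(device)
      label_mnist = mnist['label'].to(device)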
      
      ######################################################
      # Step 1: train the feature extractor and the label predictor with source data
      modelFE.zero_grad()
      modelLP.zero_grad()
      f = modelFE(image_mnist)
      y = modelLP(f)
      y = y.reshape(y.shape[:2])
      err = criterionLP(y, label_mnist)
      err.backward()
      optimizerFE.step()
      optimizerLP.step()
      
      LP_loss += err.item()  # NLLLoss already averages over the mini-batch
      
      if i % 100 == 0:
        print(err.item())
      
      #######################################################
      # Step 2: train the domain classifier with source data (F is kept fixed)
      modelDC.zero_grad()
      b_size = label_mnist.shape[0]
      label = torch.full((b_size,), source_label, dtype=torch.long, device=device)
      f = modelFE(image_mnist).detach()  # detach: no gradient flows back into F
      y = modelDC(f)
      y = y.reshape(y.shape[:2])
      err1 = criterionDC(y, label)
      
      ########################################################
      # Step 3: train the domain classifier with target data (F is kept fixed)
      b_size = label_mnistm.shape[0]
      label = torch.full((b_size,), target_label, dtype=torch.long, device=device)
      f = modelFE(image_mnistm).detach()
      y = modelDC(f)
      y = y.reshape(y.shape[:2])
      err2 = criterionDC(y, label)
      err = err1 + err2
      err.backward()
      optimizerDC.step()
      
      DC_loss_1 += err.item() / 2  # mean of the source and target losses
      
      #######################################################
      # Step 4: train the feature extractor with target data (D is kept fixed)
      # so that the domain classifier mistakenly classifies target features as
      # source features (the "flipped" label)
      modelFE.zero_grad()
      modelDC.zero_grad()
      b_size = label_mnistm.shape[0]
      label = torch.full((b_size,), source_label, dtype=torch.long, device=device)
      f = modelFE(image_mnistm)
      y = modelDC(f)
      y = y.reshape(y.shape[:2])
      err = criterionDC(y, label)
      err.backward()
      optimizerFE.step()  # only F is updated; D's gradients are cleared before D's next update
      
      DC_loss_2 += err.item()
      
    
    # average the per-batch mean losses over the epoch
    num_batches = len(mnistm_train_dataloader)
    LP_train_losses.append(LP_loss / num_batches)
    DC_train_losses_1.append(DC_loss_1 / num_batches)
    DC_train_losses_2.append(DC_loss_2 / num_batches)
    
    # update the learning rate
    schedulerLP.step()
    schedulerFE.step()
    schedulerDC.step()
    
    
    # evaluation
    modelFE.eval()
    modelLP.eval()
    LP_loss = 0
    labels_predicted = None
    labels_true = None
    for i, data in enumerate(mnist_eval_dataloader, 0):
      
      image = data['image'].to(device)
      label = data['label'].to(device)
      
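      with torch.no_grad():  # no computation graph is needed for evaluation
        f = modelFE(image)
        y = modelLP(f)
        y = y.reshape(y.shape[:2])
        err = criterionLP(y, label)
      
      LP_loss += err.item() * len(label)  # per-batch mean -> per-batch sum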
      
      pred = torch.argmax(y, dim=1).detach().cpu().numpy()
      label = label.detach().cpu().numpy()
      if labels_true is None:
        labels_true = label.copy()
        labels_predicted = pred.copy()
      else:
        labels_true = numpy.hstack((labels_true, label))
        labels_predicted = numpy.hstack((labels_predicted, pred))
      
      if i % 100 == 0:
        print(err.item())
    
    # plot the confusion matrix and save the figure    
    saveConfusionMatrix(mnist_cnf_fig, labels_true, labels_predicted)
    
    LP_eval_losses_1.append(LP_loss / len(mnist_eval_dataset))
    
    LP_loss = 0
    labels_predicted = None
    labels_true = None
    for i, data in enumerate(mnistm_eval_dataloader, 0):
      
      image = data['image'].to(device)
      label = data['label'].to(device)
      
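      with torch.no_grad():  # no computation graph is needed for evaluation
        f = modelFE(image)
        y = modelLP(f)
        y = y.reshape(y.shape[:2])
        err = criterionLP(y, label)
      
      LP_loss += err.item() * len(label)  # per-batch mean -> per-batch sum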
      
      pred = torch.argmax(y, dim=1).detach().cpu().numpy()
      label = label.detach().cpu().numpy()
      if labels_true is None:
        labels_true = label.copy()
        labels_predicted = pred.copy()
      else:
        labels_true = numpy.hstack((labels_true, label))
        labels_predicted = numpy.hstack((labels_predicted, pred))
      
      if i % 100 == 0:
        print(err.item())

    # plot the confusion matrix and save the figure    
    saveConfusionMatrix(mnistm_cnf_fig, labels_true, labels_predicted)

    LP_eval_losses_2.append(LP_loss / len(mnistm_eval_dataset))
    
    pylab.ion()
    fig = pylab.figure(0)
    fig.clf()
    pylab.plot(LP_train_losses, label='LP train loss')
    pylab.plot(LP_eval_losses_1, label='LP eval loss 1')
    pylab.plot(LP_eval_losses_2, label='LP eval loss 2')
    pylab.plot(DC_train_losses_1, label='DC train loss 1')
    pylab.plot(DC_train_losses_2, label='DC train loss 2')
    pylab.legend()
    pylab.grid()
    pylab.show(block=False)
    pylab.pause(1)
    pylab.savefig(loss_curve_fig)
  
if __name__ == '__main__':
  main()
  
  


![picture](https://drive.google.com/uc?id=1BGehE1TChCp8gPvSexVZU-Y-BOrICcGl)


![picture](https://drive.google.com/uc?id=12xdvmS-HpgtCCcWHWZ-IK2IK4IW9Yv8Q)
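
The example reports losses and saves row-normalized confusion matrices, but does not print a single accuracy figure. The hypothetical helper below (not part of the original script) computes overall accuracy and per-class recall from the `labels_true` and `labels_predicted` arrays collected in the evaluation loops; the recall values are the diagonal of the row-normalized matrix that `confusion_matrix(..., normalize='true')` produces.

```python
import numpy as np

def accuracy_and_recall(labels_true, labels_predicted, num_classes):
    """Overall accuracy and per-class recall (diagonal of the row-normalized confusion matrix)."""
    labels_true = np.asarray(labels_true)
    labels_predicted = np.asarray(labels_predicted)
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (labels_true, labels_predicted), 1)  # rows: true class, columns: predicted class
    accuracy = np.trace(cm) / cm.sum()
    recall = np.diag(cm) / np.maximum(cm.sum(axis=1), 1)  # guard against classes with no samples
    return float(accuracy), recall

acc, rec = accuracy_and_recall([0, 0, 1, 1, 2], [0, 1, 1, 1, 2], num_classes=3)
print(acc, rec)  # accuracy 0.8, recall [0.5, 1.0, 1.0]
```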
