ⓒ 2022 CCNets, Inc

https://ccnets.org


Authors : JinSu Kim, Jisu Hong, Yusang Park, Xiyana Figuera, JunHo Park

Initialization

In [None]:
%pip install torch
%pip install torchvision
%pip install pillow
%pip install matplotlib

In [None]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn.functional as F

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from IPython import display
import time

#CCNets init
# Convolutional neural network with a ResNet having 10 layers
# from resnets_ccn import cnn_ResNet10 as cnn 
# Transposed-convolutional neural network with a ResNet having 10 layers
# from resnets_ccn import transpose_cnn_ResNet10 as transpose_cnn


Inputs

In [None]:
# TODO: set the directory the trained weights are written to.
# The value is the default argument of save_models().
MODEL_PATH = ''

# Root directory for dataset
dataroot = "/content"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 32

# Number of channels in the training images. For color images this is 3
n_img_ch = 3
# Side length (pixels) the images are resized and center-cropped to
n_img_sz = 128

# Number of dimensions of the causal explanation vector in the explanatory space
dim_explanation = 256

# Dimension of the label vector
dim_label = 2

# Learning rate of the optimizers (chosen for CelebA images)
lr = 0.0002

# Number of training epochs
num_epochs = 100

# Number of scheduler steps between two learning-rate decays (StepLR step_size)
step_size = 4

# Adam beta1: coefficient used for computing the running average of the gradient
beta1 = 0.9

# Scheduler decay rate (StepLR gamma): factor the learning rate is multiplied by at each decay
gamma = 0.99954


# Number of GPUs requested; in this section it only gates the CUDA/CPU choice below
ngpu = 2

# Set random seed for reproducibility
manualSeed = 999

# ManualSeed = random.randint(1, 10000) # use if you want new results
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Decide which device we want to run on: the first CUDA device when available
# and ngpu > 0, otherwise the CPU
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else "cpu")

In [None]:
# Causal cooperative nets (CCNs) composed of an Explainer, a Reasoner, and a Producer
'''-ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
class Explainer(nn.Module):
    """Network that turns an observation into a causal explanation vector.

    ``Net`` is a factory called as ``Net(n_img_ch, dim_explanation)``; both
    sizes are read from the module-level configuration. The raw output of the
    wrapped network is passed through a sigmoid, so every component of the
    explanation lies in (0, 1).
    """

    def __init__(self, Net):
        super().__init__()
        self.n_img_ch = n_img_ch
        self.net = Net(n_img_ch, dim_explanation)
        self.sigmoid = nn.Sigmoid()

    def forward(self, src):
        """Return the sigmoid-activated explanation for the batch ``src``."""
        return self.sigmoid(self.net(src))

class Reasoner(nn.Module):
    """Network that infers labels from an observation and its explanation.

    The explanation vector is tiled into one extra image plane of size
    ``n_img_sz x n_img_sz`` and concatenated to the image along the channel
    axis, so the wrapped network is built as ``Net(n_img_ch + 1, n_label)``.
    The output is passed through a sigmoid.

    Args:
        Net: factory called as ``Net(in_channels, out_features)``.
        n_img_sz: side length of the (square) input images.
        n_img_ch: number of channels of the input images.
        n_label: number of label dimensions to infer.
        n_explain: length of the explanation vector.

    Raises:
        ValueError: if ``n_img_sz * n_img_sz`` is not a positive multiple of
            ``n_explain``; the explanation could then not fill the extra
            plane exactly.
    """

    def __init__(self, Net, n_img_sz, n_img_ch, n_label, n_explain):
        super(Reasoner, self).__init__()
        n_pixels = n_img_sz * n_img_sz
        # Fail early with a clear message instead of a reshape error (or a
        # silently wrong batch size) during the first forward pass.
        if n_explain <= 0 or n_pixels <= 0 or n_pixels % n_explain != 0:
            raise ValueError(
                "n_img_sz * n_img_sz (%d) must be a positive multiple of n_explain (%d)"
                % (n_pixels, n_explain))
        self.net = Net(n_img_ch + 1, n_label)
        self.sigmoid = nn.Sigmoid()
        self.n_explain = n_explain
        self.n_img_ch = n_img_ch
        self.n_img_sz = n_img_sz
        self.n_img_sz_2 = n_pixels

    def forward(self, image, explain):
        """Return sigmoid-activated labels inferred from ``image`` and ``explain``.

        ``explain`` is viewed as ``(-1, n_explain, 1)``; each component is
        repeated ``n_img_sz**2 // n_explain`` times and the result reshaped to
        a single ``n_img_sz x n_img_sz`` plane per sample.
        """
        e = explain.view([-1, self.n_explain, 1])
        # Integer division: the constructor guarantees an exact quotient.
        e = e.repeat(1, 1, self.n_img_sz_2 // self.n_explain)
        reshaped_e = torch.reshape(e, (-1, 1, self.n_img_sz, self.n_img_sz))
        cat = torch.cat([image, reshaped_e], 1)
        y = self.net(cat)
        y = self.sigmoid(y)
        return y
    
class Producer(nn.Module):
    """Network that generates an observation from a label and an explanation.

    Label and explanation are each tiled to the common length
    ``max(n_label, n_explain)`` and concatenated, so the wrapped network is
    built as ``Net(2 * max_n, n_img_ch)``. The output is passed through tanh.

    Args:
        Net: factory called as ``Net(in_features, out_channels)``.
        n_label: length of the label vector.
        n_explain: length of the explanation vector.
        n_img_ch: number of channels of the generated images.

    Raises:
        ValueError: if the longer of the two vectors is not a positive
            multiple of the shorter one; the tiled input would then not have
            the ``2 * max_n`` features the wrapped network was built for.
    """

    def __init__(self, Net, n_label, n_explain, n_img_ch):
        super(Producer, self).__init__()
        if n_label <= 0 or n_explain <= 0:
            raise ValueError(
                "n_label (%d) and n_explain (%d) must be positive" % (n_label, n_explain))
        self.max_n = max(n_label, n_explain)
        # Fail early with a clear message instead of a shape error inside Net.
        if self.max_n % n_label != 0 or self.max_n % n_explain != 0:
            raise ValueError(
                "max(n_label, n_explain) must be a multiple of both n_label (%d) and n_explain (%d)"
                % (n_label, n_explain))
        self.net = Net(2*self.max_n, n_img_ch)
        self.tanh = nn.Tanh()
        self.n_label = n_label
        self.n_explain = n_explain

    def forward(self, label, explain):
        """Return the tanh-activated output generated from ``label`` and ``explain``."""
        # Integer division: the constructor guarantees exact quotients.
        y = label.repeat(1, self.max_n // self.n_label)
        e = explain.repeat(1, self.max_n // self.n_explain)
        cat = torch.cat([y, e], 1)
        x = self.net(cat)
        x = self.tanh(x)
        return x

Labels

In [None]:
#Characteristics of Celebs

# 5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald    ## 0~4  
# Bangs Big_Lips Big_Nose Black_Hair Blond_Hair                       ## 5~9  
# Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin                 ## 10~14  
# Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones            ## 15~19  
# Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard              ## 20~24  
# Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks       ## 25~29  
# Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings          ## 30~34  
# Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young ## 35~39

In [None]:
'''-ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385

[2] Modified by PARK, JunHo in April 10, 2022

'''
class BasicBlock(nn.Module):
    """Residual block of two convolutions plus a shortcut branch.

    With ``transpose is True`` and ``stride != 1`` the first convolution and
    the shortcut become transposed convolutions (kernel sizes 4 and 2 instead
    of 3 and 1), so the block up-samples instead of down-sampling. The second
    convolution always uses kernel 3, stride 1, padding 1.
    """
    expansion = 1

    def __init__(self, in_planes, planes, transpose, stride=1):
        super().__init__()
        upsampling = transpose is True and stride != 1
        conv_cls = nn.ConvTranspose2d if upsampling else nn.Conv2d
        main_kernel = 4 if upsampling else 3
        skip_kernel = 2 if upsampling else 1
        out_planes = self.expansion * planes

        # Sub-modules are created in the order conv1, bn1, conv2, bn2, shortcut
        # so parameter order and seeded initialisation stay the same.
        self.conv1 = conv_cls(
            in_planes, planes, kernel_size=main_kernel, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv_cls(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride != 1 or in_planes != out_planes:
            # Projection shortcut: matches the channel count and spatial size
            # of the main branch.
            self.shortcut = nn.Sequential(
                conv_cls(in_planes, out_planes,
                         kernel_size=skip_kernel, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # Identity shortcut.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))

class ResNet(nn.Module):
    """Four-stage ResNet usable as an encoder or, transposed, as a decoder.

    ``transpose is False`` builds the encoder: a stride-2 convolution, four
    stride-2 ``BasicBlock`` stages (64, 128, 256, 512 planes), 4x4 average
    pooling and a bias-free linear layer with ``noutp`` outputs.

    Otherwise the stages run 512, 256, 128, 64 planes; with
    ``transpose is True`` the stem is a bias-free linear layer from ``ninp``
    to 512 features followed by a stride-4 transposed convolution, and the
    head is a stride-2 transposed convolution to ``noutp`` channels.
    """

    def __init__(self, ninp, noutp, transpose):
        super().__init__()
        self.expansion = 4

        # Stage widths: encoder widens, decoder narrows.
        if transpose is False:
            widths = (64, 128, 256, 512)
        else:
            widths = (512, 256, 128, 64)

        # The stages are registered before the stem/head so module order
        # (and therefore seeded initialisation) stays layer1..layer4 first.
        self.in_planes = widths[0]
        for position, width in enumerate(widths, start=1):
            setattr(self, 'layer%d' % position,
                    self._make_layer(BasicBlock, width, transpose))
        self.nlast = widths[-1]

        head_width = self.nlast*BasicBlock.expansion
        if transpose is True:
            self.linear = nn.Linear(ninp, 512, bias = False)
            self.conv1 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=4, padding=0, bias=False)
            self.bn1 = nn.BatchNorm2d(512)
            self.conv2 = nn.ConvTranspose2d(head_width, noutp, kernel_size=2, stride=2, padding=0, bias=False)
        else:
            self.conv1 = nn.Conv2d(ninp, 64, kernel_size=2, stride=2, padding=0, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.linear2 = nn.Linear(head_width, noutp, bias = False)

        self.transpose = transpose

    def _make_layer(self, block, planes, transpose):
        """Build one stage: a single stride-2 ``block``; updates ``self.in_planes``."""
        stage = nn.Sequential(block(self.in_planes, planes, transpose, stride = 2))
        self.in_planes = planes * block.expansion
        return stage

    def forward(self, x):
        """Run the stem, the four stages and the head appropriate to the mode."""
        if self.transpose is True:
            # Project the flat input to 512 features and treat it as a 1x1 map.
            x = self.linear(x).view([-1, 512, 1, 1])
        out = F.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        if self.transpose is False:
            # linear2 expects 512 features, i.e. a 4x4 map before pooling.
            pooled = F.avg_pool2d(out, 4)
            return self.linear2(pooled.view(pooled.size(0), -1))
        return self.conv2(out)

def resNet10(ninp, noutp):
    """Build the convolutional (down-sampling) ResNet: ``ninp`` channels in, ``noutp`` features out."""
    encoder = ResNet(ninp, noutp, transpose=False)
    return encoder

def transpose_resNet10(ninp, noutp):
    """Build the transposed (up-sampling) ResNet: ``ninp`` features in, ``noutp`` channels out."""
    decoder = ResNet(ninp, noutp, transpose=True)
    return decoder


TrainLoader / DataLoader

In [None]:
# CelebA training split, downloaded into `dataroot` when not already present.
# Pre-processing: resize to n_img_sz, convert to a tensor, center-crop to
# n_img_sz x n_img_sz, then normalize each channel with mean 0.5 and std 0.5.
_train_transform = transforms.Compose([
    transforms.Resize(n_img_sz),
    transforms.ToTensor(),
    transforms.CenterCrop(n_img_sz),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = dset.CelebA(
    root=dataroot,
    split="train",
    transform=_train_transform,
    download=True,
)

# Get trainloader
def get_trainloader():
    """Return a new shuffling ``DataLoader`` over ``trainset``.

    Batch size and worker count come from the module-level ``batch_size``
    and ``workers`` settings, so changing ``workers`` takes effect here
    instead of being overridden by a hard-coded value.
    """
    return torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=workers)


CCNets Causal Learning Models

In [None]:
# Initialize the CCNs composed of an Explainer, a Reasoner, and a Producer,
# and move each network to `device`. The construction order (Explainer,
# Reasoner, Producer) is kept as the networks are created one after another.
explainer = Explainer(Net=resNet10).to(device)
reasoner = Reasoner(
    Net=resNet10,
    n_img_sz=n_img_sz,
    n_img_ch=n_img_ch,
    n_label=dim_label,
    n_explain=dim_explanation,
).to(device)
producer = Producer(
    Net=transpose_resNet10,
    n_label=dim_label,
    n_explain=dim_explanation,
    n_img_ch=n_img_ch,
).to(device)

Variables for debugging

In [None]:
# Settings for the debug image canvas.
# The canvas holds one row of source images followed by n_canvas_row rows of
# generated images, with n_canvas_col images per row (at most 6).
fig_len = 15
n_canvas_row = 4
n_canvas_col = min(6, batch_size)
# Blank (all-ones) image in channel-first layout.
m_empty_image = np.ones((n_img_ch, n_img_sz, n_img_sz))
# Shared canvas buffer in height x width x channel layout, initialised to ones.
m_img_canvas = np.ones((
    (n_canvas_row + 1) * n_img_sz,
    n_canvas_col * n_img_sz,
    n_img_ch,
))

Print images

In [None]:
# Show images while debugging
def create_label_controlled_images(images, labels):
    """Fill the shared debug canvas and return it.

    The top row shows the first ``n_canvas_col`` source images. Each of the
    following ``n_canvas_row`` rows shows images produced from the explanations
    of the whole batch, after the two label components of the first
    ``n_canvas_col`` samples were overwritten with a fixed pair:
    (0, 1), (0, 0), (1, 1) and (1, 0) for rows 0 to 3.

    The module-level ``m_img_canvas`` array is modified in place and returned.
    """
    # (label[0], label[1]) written per canvas row; rows beyond the table keep
    # the original labels.
    row_settings = ((0., 1.), (0., 0.), (1., 1.), (1., 0.))
    canvas_width = n_img_sz*(n_canvas_col)

    base_labels = labels.type(torch.float).clone()
    with torch.no_grad():
        explains = explainer(images.to(device))

    source_grid = vutils.make_grid(images[:n_canvas_col], padding = 0, normalize=True)
    m_img_canvas[:n_img_sz, :canvas_width] = np.transpose(source_grid, (1,2,0))

    for row in range(n_canvas_row):
        row_labels = base_labels.clone()
        if row < len(row_settings):
            first, second = row_settings[row]
            for col in range(n_canvas_col):
                row_labels[col][0] = first
                row_labels[col][1] = second

        # The whole batch is generated; only the first n_canvas_col are drawn.
        with torch.no_grad():
            generated = producer(row_labels.to(device), explains).cpu()
        row_grid = vutils.make_grid(generated[:n_canvas_col], padding = 0, normalize=True)

        top = n_img_sz*(row+1)
        m_img_canvas[top:top + n_img_sz, :canvas_width] = np.transpose(row_grid, (1,2,0))
    return m_img_canvas

In [None]:
def print_figure(show_image):
    """Display ``show_image`` in a square ``fig_len`` x ``fig_len`` figure without axes."""
    plt.figure(figsize=(fig_len, fig_len))
    axes = plt.subplot(1, 1, 1)
    axes.imshow(show_image, interpolation="sinc")
    axes.axis("off")
    plt.show()

Optimizers & Scheduler

In [None]:
# Optimizers and learning-rate schedulers for the explainer, reasoner, and producer.
# `momentum` is defined here but not passed to the Adam optimizers below.
momentum=0.9


def _build_adam(model):
    # Adam with the configured learning rate and beta1; beta2 is fixed at 0.999.
    return optim.Adam(model.parameters(), lr=lr, betas=(beta1, 0.999))


def _build_step_lr(optimizer):
    # StepLR: multiply the learning rate by `gamma` every `step_size` scheduler steps.
    return optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)


opt_explainer, opt_reasoner, opt_producer = (
    _build_adam(model) for model in (explainer, reasoner, producer))

scheduler_explainer, scheduler_reasoner, scheduler_producer = (
    _build_step_lr(optimizer) for optimizer in (opt_explainer, opt_reasoner, opt_producer))


In [None]:
# Advance the learning-rate schedulers
def step_scheduler():
    """Advance the explainer, reasoner and producer schedulers by one step each."""
    for scheduler in (scheduler_explainer, scheduler_reasoner, scheduler_producer):
        scheduler.step()

Save models

In [None]:
def save_models(model_path = MODEL_PATH):
    """Save the state_dict of each network as ``<name>.pth`` under ``model_path``.

    Writes ``explainer.pth``, ``reasoner.pth`` and ``producer.pth``. A
    non-empty ``model_path`` is created first when it does not exist yet, so
    saving no longer fails with ``FileNotFoundError`` after a full training
    run. An empty ``model_path`` saves into the current working directory.
    """
    if model_path:
        os.makedirs(model_path, exist_ok=True)
    named_models = (
        ('explainer', explainer),
        ('reasoner', reasoner),
        ('producer', producer),
    )
    for name, model in named_models:
        torch.save(model.state_dict(), os.path.join(model_path, name + '.pth'))

Set up training

In [None]:
# Set up training for each batch

def requires_grads(require_grad = True):
    """Set ``requires_grad`` to ``require_grad`` on every parameter of the three networks."""
    for network in (explainer, reasoner, producer):
        network.requires_grad_(require_grad)

def zero_grads():
    """Clear the accumulated gradients of the explainer, reasoner and producer."""
    for network in (explainer, reasoner, producer):
        network.zero_grad()

def set_train():
    """Prepare the three networks for one optimisation step.

    Enables gradient tracking on all parameters, then clears any gradients
    left over from the previous step.
    """
    requires_grads(require_grad=True)
    zero_grads()

In [None]:
def get_error(loss):
    """Return the mean of the tensor ``loss`` as a plain Python number."""
    mean_value = loss.mean()
    # Detach from autograd and move to host memory before extracting the scalar.
    return mean_value.detach().cpu().item()

Loss Function 

In [None]:
# L1 criteria: element-wise (no reduction) and averaged over all elements.
_loss_L1 = nn.L1Loss(reduction='none')
_loss_L1_mean = nn.L1Loss(reduction='mean')


def loss_function(prediction, target):
    """Return the mean L1 distance between ``prediction`` and ``target``.

    ``target`` is detached first, so gradients flow only into ``prediction``.
    """
    frozen_target = target.detach()
    return _loss_L1_mean(prediction, frozen_target)


def error_function(prediction, target):
    """Return the element-wise L1 distance between ``prediction`` and ``target``.

    No reduction is applied, so the result has the shape of the inputs.
    ``target`` is detached first, so gradients flow only into ``prediction``.
    """
    frozen_target = target.detach()
    return _loss_L1(prediction, frozen_target)


Backwards 

Causal Training with a labeled dataset

In [None]:
# Causal Training (training mode A) of Causal Cooperative Nets (CCNs)

# Source and target are a batch of images and labels
def causal_training(source, target):
    """Run one causal-training step on a labeled batch and update all three networks.

    Args:
        source: batch of images; moved to ``device`` before use.
        target: batch of labels; cast to float and moved to ``device``.

    Returns:
        A pair ``(losses, errors)`` of NumPy arrays with three entries each:
        ``losses`` holds the inference, generation and reconstruction losses,
        ``errors`` holds the explainer, reasoner and producer errors
        (all converted to Python numbers with ``get_error``).
    """

    # Observation is the source on the target device
    observation = source.to(device)
    # Label is the target as a float tensor on the target device
    label = target.type(torch.float).to(device)

    # Set up training for a batch: enable gradients and clear old ones
    set_train()

    # Explainer receives the observation as input and outputs the causal explanation vector
    causal_explanation = explainer(observation)

    # Reasoner receives the observation and the causal explanation as inputs and outputs inferred labels
    inferred_label = reasoner(observation, causal_explanation)
    ###################################################

    # Producer receives a label and a causal explanation as inputs and outputs the generated observation
    generated_observation = producer(label, causal_explanation)

    # Producer receives an inferred label and a causal explanation as inputs and outputs the reconstructed observation.
    # The causal explanation passed to the producer here is detached from the backward pass
    reconstructed_observation = producer(inferred_label, causal_explanation.detach())
    ###################################################

    # A set of prediction losses: an inference loss, a generation loss, and a reconstruction loss.
    # Each is a mean-reduced L1 loss (loss_function) whose second argument is
    # treated as a fixed target (it is detached inside loss_function):
    #   inference:      reconstructed observation vs. generated observation
    #   generation:     generated observation vs. input observation
    #   reconstruction: reconstructed observation vs. input observation
    inference_loss = loss_function(reconstructed_observation, generated_observation)
    generation_loss = loss_function(generated_observation, observation)
    reconstruction_loss = loss_function(reconstructed_observation, observation)
    ###################################################

    # Model errors of Explainer, Reasoner, and Producer calculated from the prediction losses.
    # Each error is |sum of two losses - third loss|, with the third loss
    # detached inside error_function so it receives no gradient from that error.
    explainer_error = error_function(inference_loss + generation_loss, reconstruction_loss)
    reasoner_error = error_function(reconstruction_loss + inference_loss, generation_loss)
    producer_error = error_function(generation_loss + reconstruction_loss, inference_loss)
    ###################################################

    # Compute gradients of the error functions with respect to the parameters of Explainer, Reasoner, and Producer respectively.
    # The model errors (explainer error, reasoner error, and producer error) are backpropagated
    # through the propagation paths created by the prediction losses,
    # which are an inference loss, a generation loss, and a reconstruction loss.
    # The graph is retained because the later errors share it.
    explainer_error.backward(retain_graph = True)
    # NOTE(review): presumably meant to keep the following backward passes from
    # reaching the Explainer - confirm that an in-place detach after the graph
    # has been built has that effect.
    causal_explanation.detach_()

    # Discard the gradients the explainer error left in the Reasoner and Producer
    reasoner.zero_grad()
    producer.zero_grad()
    reasoner_error.backward(retain_graph = True)
    # NOTE(review): same intent as above, here for the Reasoner - confirm.
    inferred_label.detach_()

    # Discard the gradients the reasoner error left in the Producer
    producer.zero_grad()
    # Last backward pass: the graph is no longer needed afterwards
    producer_error.backward(retain_graph = False)

    # Update model parameters
    opt_explainer.step()
    opt_reasoner.step()
    opt_producer.step()

    # Losses are prediction losses and errors are model errors
    losses = np.array([get_error(inference_loss), get_error(generation_loss), get_error(reconstruction_loss)])
    errors = np.array([get_error(explainer_error), get_error(reasoner_error), get_error(producer_error)])

    return losses, errors

Print losses & errors

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer; the log directory is named after the scheduler step size.
logger = SummaryWriter(log_dir=f"./logger/{step_size}")


# Print prediction losses, and model errors while training
def print_training(epoch, num_epochs, idx, len_dataloader, errors, losses):
    """Print and log the current prediction losses and model errors.

    Args:
        epoch: index of the current epoch.
        num_epochs: total number of epochs.
        idx: index of the current batch within the epoch.
        len_dataloader: number of batches per epoch.
        errors: sequence ``(explainer, reasoner, producer)`` of model errors.
        losses: sequence ``(inference, generation, reconstruction)`` of
            prediction losses.

    The six values are also written to TensorBoard under ``Train/<name>`` at
    global step ``epoch * len_dataloader + idx``.
    """
    print('[%d/%d][%d/%d]'
        % (epoch, num_epochs, idx, len_dataloader))
    print('loss\t Inference: %.4f\tGeneration: %.4f\tReconstruction: %.4f'
        % (losses[0], losses[1], losses[2]))
    print('Explainer: %.4f\tReasoner: %.4f\tProducer: %.4f'
        % (errors[0], errors[1], errors[2]))

    global_step = epoch*len_dataloader+idx
    # Tags use the plain network names; the previous "tReasoner"/"tProducer"
    # tags carried a stray "t" copied from the "\t" of the print format above.
    scalars = (
        ("Train/Inference", losses[0]),
        ("Train/Generation", losses[1]),
        ("Train/Reconstruction", losses[2]),
        ("Train/Explainer", errors[0]),
        ("Train/Reasoner", errors[1]),
        ("Train/Producer", errors[2]),
    )
    for tag, value in scalars:
        logger.add_scalar(tag, value, global_step)

    logger.flush()

        

Training

In [None]:
# Main training loop: runs causal training on every batch, periodically shows
# label-controlled debug images and prints/logs running averages.
print("Starting Training Loop...")

pvt_time = time.time()
cnt_iter = 0

record_avg_error = []
record_avg_loss = []

# Show debug images every n_debug batches, print statistics every n_print batches
n_debug = 100
n_print = 50

m_zeros3 = np.array([0., 0., 0.])

previous_err = np.copy(m_zeros3)
previous_loss = np.copy(m_zeros3)

# Attribute columns used as labels; built once as it is the same for every batch
indices = torch.tensor([20, 31]) # Male, Smiling

for epoch in range(num_epochs):
    # NOTE(review): re-seeding from the wall clock at every epoch overrides the
    # fixed manualSeed set in the Inputs section - confirm this is intended.
    random.seed(pvt_time)
    torch.manual_seed(pvt_time)

    dataloader = get_trainloader()
    len_dataloader = len(dataloader)

    # Per-epoch running sums of the three model errors and three prediction losses
    m_err_sum = np.zeros(3, dtype=np.float64)
    m_loss_sum = np.zeros(3, dtype=np.float64)

    for i, (image_batch, label_batch) in enumerate(dataloader):
        label_batch = torch.index_select(label_batch, 1, indices)
        # causal_training returns (losses, errors) in that order; unpacking
        # them the other way round mislabels every printed and logged value.
        loss, err = causal_training(image_batch, label_batch)

        m_loss_sum += loss
        m_err_sum += err

        if i % n_debug == 0:
            display.clear_output(wait=True)
            show_image = create_label_controlled_images(image_batch, label_batch)
            print_figure(show_image)

            # Print the current learning rate of the Explainer, Reasoner, and Producer respectively
            print('Epoch-{0} lr_E: {1} lr_R: {2} lr_P: {3}'.format(epoch, \
                opt_explainer.param_groups[0]['lr'], opt_reasoner.param_groups[0]['lr'], opt_producer.param_groups[0]['lr']))
            print_cnt = 0

        # Every n_print iterations, print the training info (averages over the epoch so far)
        if i % n_print == 0:
            err_avg = m_err_sum/(i+1)
            loss_avg = m_loss_sum/(i+1)

            print_training(epoch, num_epochs, i, len_dataloader, err_avg, loss_avg)

            # The interval is measured from the previous report, not from the
            # start of the epoch, and the message says so.
            cur_time = time.time()
            print ('Time since previous report (epoch {}) is {} sec'.format(epoch + 1, cur_time - pvt_time))
            pvt_time = cur_time
            print_cnt += 1

            cnt_iter = 0

        # The schedulers are advanced once per batch, not once per epoch
        step_scheduler()

In [None]:
# Display again the most recent debug canvas created in the training loop
print_figure(show_image)

In [None]:
# Save the state_dicts of the explainer, reasoner and producer as .pth files
# in the local directory given by MODEL_PATH
save_models(MODEL_PATH)