<a href="https://colab.research.google.com/github/purbayankar/TransformerCLSWGAN/blob/main/CLSWGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
from torch.autograd import Variable
from torch import autograd
import time


def calc_gradient_penalty(opt, netD, res_real, res_fake, att):
    alpha = torch.rand(opt.length, 1)
    alpha = alpha.expand(res_real.size())
    if opt.cuda:
        alpha = alpha.cuda()
    interpolates = alpha * res_real + ((1 - alpha) * res_fake)
    interpolates = Variable(interpolates, requires_grad=True)
    disc_interpolates = netD(interpolates, Variable(att))
    ones = torch.ones(disc_interpolates.size())
    if opt.cuda:
        ones = ones.cuda()
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=ones,
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * opt.lambda1
    return gradient_penalty


def first_stage_train(opt, models, loaders):
    one = torch.tensor(1.)
    mone = torch.tensor(-1.)
    if opt.cuda:
        one = one.cuda()
        mone = mone.cuda()
    for i in range(opt.first_epoch):
        epoch_start_time = time.time()
        print('The {}th epoch starts.'.format(i + 1))
        for res_real, res_real_class, res_real_att in loaders.img_loader_train:
            opt.length = res_real_class.shape[0]
            # Train the discriminator.
            for param in models.netD.parameters():
                param.requires_grad = True
            models.netD.zero_grad()
            noise = torch.FloatTensor(opt.length, opt.nz_size)
            noise.normal_()
            if opt.cuda:
                res_real_class = res_real_class.cuda()
                res_real = res_real.cuda()
                noise = noise.cuda()
                res_real_att = res_real_att.cuda()
            att_dis = Variable(res_real_att, requires_grad=True)
            noise_dis = Variable(noise)
            res_real_dis = Variable(res_real, requires_grad=True)

            dis_real = models.netD(res_real_dis, att_dis)
            dis_real_mean = dis_real.mean()
            dis_real_mean.backward(mone)

            res_fake = models.netG(noise_dis, att_dis)
            dis_fake = models.netD(res_fake, att_dis)
            dis_fake_mean = dis_fake.mean()
            dis_fake_mean.backward(one)

            gradient_penalty = calc_gradient_penalty(opt=opt, netD=models.netD, res_real=res_real, res_fake=res_fake,
                                                     att=att_dis)
            gradient_penalty.backward()
            models.optimizerD.step()

            if i % 5 == 0:
                # Train the generator.
                for param in models.netD.parameters():
                    param.requires_grad = False

                models.netG.zero_grad()

                noise.normal_()
                noise_gen = Variable(noise)
                att_gen = Variable(res_real_att, requires_grad=True)
                res_fake = models.netG(noise_gen, att_gen)
                dis_fake = models.netD(res_fake, att_gen)
                dis_fake_mean = - dis_fake.mean()
                cls_result = models.cls(res_fake)
                cls_loss = models.cls_criterion(cls_result, res_real_class.squeeze_())

                gen_loss = dis_fake_mean + opt.cls_weight * cls_loss
                gen_loss.backward()
                models.optimizerG.step()

        print('This epoch use {} mins {} secs'.format(int((time.time() - epoch_start_time) / 60),
                                                      int((time.time() - epoch_start_time) % 60)))


def second_stage_train(opt, models, loaders):
    for param in models.netG.parameters():
        param.requires_grad = False

    for i in range(opt.second_epoch):
        epoch_start_time = time.time()
        print('The {}th epoch starts.'.format(i + 1))
        correct_num = 0
        complete_num = 0
        for _, res_real_class, res_real_att in loaders.img_loader_test:
            opt.len_index = res_real_class.shape[0]
            models.cls.zero_grad()
            noise = torch.FloatTensor(opt.len_index, opt.nz_size)
            noise.normal_()
            if opt.cuda:
                res_real_class = res_real_class.cuda()
                noise = noise.cuda()
                res_real_att = res_real_att.cuda()
            res_real_att = Variable(res_real_att, requires_grad=True)
            res_fake = models.netG(noise, res_real_att)
            cls_result = models.cls(res_fake)
            cls_loss = models.cls_criterion(cls_result, res_real_class.squeeze_())
            cls_loss.backward()
            models.optimizerC.step()

            pred = cls_result.data.max(1)[1]
            correct_num += (pred == res_real_class).sum()
            complete_num += res_real_class.shape[0]
        acc = float(correct_num) / float(complete_num)
        print('Post-Training Acc = {}'.format(acc))
        print('-----------------------------------------------------')
        print('This epoch use {} mins {} secs'.format(int((time.time() - epoch_start_time) / 60),
                                                      int((time.time() - epoch_start_time) % 60)))
    print('Final Post-Training Acc = {}'.format(acc))


def final_test(opt, models, loaders):
    correct_num = 0
    complete_num = 0
    for res_real, res_real_class, res_real_att in loaders.img_loader_test:
        opt.len_index = res_real_class.shape[1]
        if opt.cuda:
            res_real = res_real.cuda()
            res_real_class = res_real_class.cuda()
        cls_result = models.cls(res_real)
        pred = cls_result.data.max(1)[1]
        correct_num += (pred == res_real_class.squeeze_()).sum()
        complete_num += res_real_class.shape[0]
    acc = float(correct_num) / float(complete_num)
    print('Final Acc = {}'.format(acc))

In [3]:
import torch


class init_args:
    def __init__(self):
        self.lr = 0.0001
        self.lr_c = 0.001
        self.beta1 = 0.5
        self.res_size = 2048
        self.nz_size = 312
        self.nz_res_ratio = 1.
        self.att_num = 312
        self.class_num = 200
        self.pre_train_epoch = 100
        self.first_epoch = 1000
        self.second_epoch = 400
        self.cuda = torch.cuda.is_available()
        self.res_path = '/content/drive/MyDrive/xlsa17/data/AWA1/res101.mat'
        self.att_path = '/content/drive/MyDrive/xlsa17/data/AWA1/att_splits.mat'
        self.shuffle = True
        self.batch_size = 2000
        self.lambda1 = 10
        self.cls_weight = 1
        self.length = 0

In [4]:
from torch import nn
from torch.autograd import Variable


class classifier(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(classifier, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.lsm = nn.LogSoftmax(dim=1)

    def forward(self, inputs):
        output = self.fc(inputs)
        output = self.lsm(output)
        return output



def pre_train(opt, models, loaders):
    for i in range(opt.pre_train_epoch):
        correct_num = 0
        complete_num = 0
        image_class: object
        for res_real, res_real_class, _ in loaders.img_loader_train:
            models.cls.zero_grad()
            if opt.cuda:
                res_real = res_real.cuda()
                res_real_class = res_real_class.cuda()
            res_real = Variable(res_real, requires_grad=True)
            cls_result = models.cls(res_real)
            cls_loss = models.cls_criterion(cls_result, res_real_class.squeeze_())
            cls_loss.backward()
            models.optimizerC.step()
            pred = cls_result.data.max(1)[1]
            correct_num += (pred == res_real_class).sum()
            complete_num += res_real_class.shape[0]
        acc = float(correct_num) / float(complete_num)
        print('Pre-Training Acc = {}'.format(acc))
        print('-----------------------------------------------------')
    print('Final Pre-Training Acc = {}'.format(acc))

In [5]:
!pip install axial_attention
import torch.nn as nn
import torch
import torch.nn.functional as F
from axial_attention import AxialAttention


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)


class Discriminator(nn.Module):
    def __init__(self, opt):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(opt.res_size + opt.att_num, 1024)
        self.attn = AxialAttention(
            dim = 1024,
            heads = 8,
            dim_index = 1,
            num_dimensions = 1
        )
        self.fc2 = nn.Linear(1024, 1)
        self.lrelu = nn.LeakyReLU(0.2, True)

        # self.apply(weights_init)

    def forward(self, x, att):
        h = torch.cat((x, att), 1)
        h = self.lrelu(self.fc1(h))
        h = torch.unsqueeze(h,2)
        h = self.attn(h)
        h = torch.squeeze(h)
        h = self.fc2(h)
        return h

# class Discriminator(nn.Module):
#     def __init__(self, opt):
#         super(Discriminator, self).__init__()
#         self.fc1 = nn.Linear(opt.res_size + opt.att_num, 1024)
#         self.fc2 = nn.Linear(1024, 1024)
#         self.fc3 = nn.Linear(1024, 1024)
#         self.fc4 = nn.Linear(1024, 1)
#         self.lrelu = nn.LeakyReLU(0.2, True)
#         self.apply(weights_init)

#     def forward(self, x, att):
#         h = torch.cat((x, att), 1)
#         h = self.lrelu(self.fc1(h))
#         h = self.lrelu(self.fc2(h))
#         h = self.lrelu(self.fc3(h))
#         h = self.fc4(h)
#         return h

# class Discriminator(nn.Module):
#     def __init__(self, opt):
#         super(Discriminator, self).__init__()
#         self.fc1 = nn.Linear(opt.res_size + opt.att_num, 1024)
#         self.fc2 = nn.Linear(1024, 1)
#         self.fc_skip = nn.Linear(opt.att_num, 1024)
#         self.lrelu = nn.LeakyReLU(0.2, True)
#         self.sigmoid = nn.Sigmoid()

#         self.apply(weights_init)

#     def forward(self, x, att):
#         h = torch.cat((x, att), 1)
#         h = self.lrelu(self.fc1(h))
#         h2 = self.lrelu(self.fc_skip(att))
#         h = self.sigmoid(self.fc2(h+h2))
#         return h


class Generator(nn.Module):
    def __init__(self, opt):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(opt.att_num + opt.nz_size, 4096)
        self.attn = AxialAttention(
            dim = 4096,
            heads = 8,
            dim_index = 1,
            num_dimensions = 1
        )
        self.fc2 = nn.Linear(4096, opt.res_size)
        self.lrelu = nn.LeakyReLU(0.2, True)
        self.relu = nn.ReLU(True)

        # self.apply(weights_init)

    def forward(self, noise, att):
        h = torch.cat((noise, att), 1)
        h = self.lrelu(self.fc1(h))
        h = torch.unsqueeze(h,2)
        h = self.attn(h)
        h = torch.squeeze(h)
        h = self.relu(self.fc2(h))
        return h


# class Generator(nn.Module):
#     def __init__(self, opt):
#         super(Generator, self).__init__()
#         self.fc1 = nn.Linear(opt.att_num + opt.nz_size, 4096)
#         self.fc2 = nn.Linear(4096, 4096)
#         self.fc3 = nn.Linear(4096, 4096)
#         self.fc4 = nn.Linear(4096, opt.res_size)
#         self.lrelu = nn.LeakyReLU(0.2, True)
#         #self.prelu = nn.PReLU()
#         self.relu = nn.ReLU(True)

#         self.apply(weights_init)

#     def forward(self, noise, att):
#         h = torch.cat((noise, att), 1)
#         h = self.lrelu(self.fc1(h))
#         h = self.lrelu(self.fc2(h))
#         h = self.lrelu(self.fc3(h))
#         h = self.relu(self.fc4(h))
#         return h

# class Generator(nn.Module):
#     def __init__(self, opt):
#         super(Generator, self).__init__()
#         self.fc1 = nn.Linear(opt.att_num + opt.nz_size, 4096)
#         #self.fc2 = nn.Linear(opt.ngh, opt.ngh)
#         #self.fc2 = nn.Linear(opt.ngh, 1024)
#         self.fc2 = nn.Linear(4096, opt.res_size)
#         self.fc_skip = nn.Linear(opt.att_num, opt.res_size)
#         self.lrelu = nn.LeakyReLU(0.2, True)
#         #self.prelu = nn.PReLU()
#         self.relu = nn.ReLU(True)

#         self.apply(weights_init)

#     def forward(self, noise, att):
#         h = torch.cat((noise, att), 1)
#         h = self.lrelu(self.fc1(h))
#         #h = self.lrelu(self.fc2(h))
#         h = self.relu(self.fc2(h))
#         h2 = self.fc_skip(att)
#         return h+h2

Collecting axial_attention
  Downloading axial_attention-0.6.1-py3-none-any.whl.metadata (560 bytes)
Downloading axial_attention-0.6.1-py3-none-any.whl (6.0 kB)
Installing collected packages: axial_attention
Successfully installed axial_attention-0.6.1


In [6]:
import torch
# import GAN
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
import scipy.io as sio
# import train_classifier


class init_models:
    def __init__(self, opt):
        self.cls = classifier(opt.res_size, opt.class_num)
        self.netD = Discriminator(opt=opt)
        self.netG = Generator(opt=opt)
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizerC = optim.Adam(self.cls.parameters(), lr=opt.lr_c, betas=(opt.beta1, 0.999))
        self.cls_criterion = torch.nn.NLLLoss()
        if opt.cuda:
            self.cls.cuda()
            self.netD.cuda()
            self.netG.cuda()
            self.cls_criterion.cuda()


class images_set_train:
    def __init__(self, opt):
        att_origin = sio.loadmat(opt.att_path)
        res_origin = sio.loadmat(opt.res_path)
        att = att_origin['att']
        res = res_origin['features']
        label = res_origin['labels'] - 1
        loc = att_origin['trainval_loc'].squeeze() - 1

        self.label = torch.from_numpy(label[loc]).long()
        self.res = torch.from_numpy(res[:, loc]).float().T
        F.normalize(self.res,p=2,dim=1)
        self.att = torch.from_numpy(att).float().T
        F.normalize(self.att,p=2,dim=1)
        opt.res_size = self.res.shape[1]
        opt.att_num = self.att.shape[1]
        opt.class_num = self.att.shape[0]
        opt.nz_size = int(opt.res_size * opt.nz_res_ratio)

    def __getitem__(self, index):
        """
        :param index: the index of the res feature.
        :return: res, label, att.
        """
        return self.res[index, :], self.label[index], self.att[self.label[index][0]]

    def __len__(self):
        return self.label.shape[0]


class images_set_test:
    def __init__(self, opt):
        att_origin = sio.loadmat(opt.att_path)
        res_origin = sio.loadmat(opt.res_path)
        att = att_origin['att']
        res = res_origin['features']
        label = res_origin['labels'] - 1
        loc = att_origin['test_unseen_loc'].squeeze() - 1

        self.label = torch.from_numpy(label[loc]).long()
        self.res = torch.from_numpy(res[:, loc]).float().T
        F.normalize(self.res,p=2,dim=1)
        self.att = torch.from_numpy(att).float().T
        F.normalize(self.att,p=2,dim=1)


    def __getitem__(self, index):
        """
        :param index: the index of the res feature.
        :return: res, label, att.
        """
        return self.res[index, :], self.label[index], self.att[self.label[index][0]]

    def __len__(self):
        return self.label.shape[0]


class loaders:
    def __init__(self, opt):
        img_data_train = images_set_train(opt)
        img_data_test = images_set_test(opt)
        self.img_loader_train = DataLoader(img_data_train, batch_size=opt.batch_size, shuffle=opt.shuffle)
        self.img_loader_test = DataLoader(img_data_test, batch_size=opt.batch_size, shuffle=opt.shuffle)

In [7]:
print('Preparing Args')

opt = init_args()

print('Preparing Loaders')

loaders = loaders(opt=opt)

print('Preparing Models')

models = init_models(opt=opt)

print('Start Pre-Training')

pre_train(opt=opt, models=models, loaders=loaders)

print('Start First Training Stage')

first_stage_train(opt=opt, models=models, loaders=loaders)

second_stage_train(opt=opt, models=models, loaders=loaders)

final_test(opt=opt, models=models, loaders=loaders)

print('Done!')

Preparing Args
Preparing Loaders
Preparing Models
Start Pre-Training
Pre-Training Acc = 0.5295986284792255
-----------------------------------------------------
Pre-Training Acc = 0.8662767244856797
-----------------------------------------------------
Pre-Training Acc = 0.902027027027027
-----------------------------------------------------
Pre-Training Acc = 0.9166498588140379
-----------------------------------------------------
Pre-Training Acc = 0.9269362646228317
-----------------------------------------------------
Pre-Training Acc = 0.9367688584106495
-----------------------------------------------------
Pre-Training Acc = 0.9415590964098427
-----------------------------------------------------
Pre-Training Acc = 0.9477611940298507
-----------------------------------------------------
Pre-Training Acc = 0.9520976200080677
-----------------------------------------------------
Pre-Training Acc = 0.9551734570391287
-----------------------------------------------------
Pre-Training