Downloading The Dataset

In [0]:
!gdown https://drive.google.com/uc?id=1I0qw-65KBA6np8vIZzO6oeiOvcDBttAY 
!unrar x "/content/ISTD_Dataset.rar" 

Making Dataset      (Image Loader)

In [0]:
import glob
import os
from PIL import Image
import torchvision.transforms as transforms
import torch.utils.data as DATA

In [0]:
def make_dataset():
    dataset = []
    original_img_rpath = '/content/ISTD_Dataset/train/train_A'
    shadow_mask_rpath = '/content/ISTD_Dataset/train/train_B'
    shadow_free_img_rpath = '/content/ISTD_Dataset/train/train_C'
    for img_path in glob.glob(os.path.join(original_img_rpath, '*.png')):
        basename = os.path.basename(img_path)
        original_img_path = os.path.join(original_img_rpath, basename)
        shadow_mask_path = os.path.join(shadow_mask_rpath, basename)
        shadow_free_img_path = os.path.join(shadow_free_img_rpath, basename)
        #print(original_img_path, shadow_mask_path, shadow_free_img_path)
        dataset.append([original_img_path, shadow_mask_path, shadow_free_img_path])
    #print(dataset)
    return dataset



class shadow_triplets_loader(DATA.Dataset):
    def __init__(self):
        super(shadow_triplets_loader, self).__init__()
        self.train_set_path = make_dataset()

    def __getitem__(self, item):
        original_img_path, shadow_mask_path, shadow_free_img_path = self.train_set_path[item]
        transform = transforms.ToTensor()
        #print(original_img_path, shadow_mask_path, shadow_free_img_path)
        original_img = Image.open(original_img_path)
        shadow_mask = Image.open(shadow_mask_path)
        shadow_free_img = Image.open(shadow_free_img_path)

        original_img = transform(original_img.resize((256, 256)))
        shadow_mask = transform(shadow_mask.resize((256, 256)))
        shadow_free_img = transform(shadow_free_img.resize((256, 256)))

        return original_img, shadow_mask, shadow_free_img

    def __len__(self):
        return len(self.train_set_path)

Stcgan_net

In [0]:
import torch
import torch.nn as nn

In [0]:
class Generator_first(nn.Module):
    def  __init__(self):
        super(Generator_first, self).__init__()
        self.conv0 = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.LeakyReLU(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU()
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU()
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.ReLU()
        )
        self.convt6 = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.convt7 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.convt8 = nn.Sequential(
            nn.ConvTranspose2d(1024, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.convt9 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.convt10 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.convt11 = nn.Sequential(
            nn.ConvTranspose2d(128, 1, 3, 1, 1),
            nn.Tanh()
        )

        self._initialize_weights()


    def forward(self, input):
        conv0 = self.conv0(input)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv4 = self.conv4(conv4)
        conv4 = self.conv4(conv4)
        conv5 = self.conv5(conv4)
        convt6 = self.convt6(conv5)
        conv6 = torch.cat((conv4, convt6), 1)
        convt7 = self.convt7(conv6)
        conv6 = torch.cat((conv4, convt7), 1)
        convt7 = self.convt7(conv6)
        conv6 = torch.cat((conv4, convt7), 1)
        convt7 = self.convt7(conv6)
        conv7 = torch.cat((conv3, convt7), 1)
        convt8 = self.convt8(conv7)
        conv8 = torch.cat((conv2, convt8), 1)
        convt9 = self.convt9(conv8)
        conv9 = torch.cat((conv1, convt9), 1)
        convt10 = self.convt10(conv9)
        conv10 = torch.cat((conv0, convt10), 1)
        convt11 = self.convt11(conv10)

        return convt11

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.normal_(m.weight, mean=0, std=0.02)
                torch.nn.init.constant_(m.bias, 0.1)


class Generator_second(nn.Module):
    def  __init__(self):
        super(Generator_second, self).__init__()
        self.conv0 = nn.Sequential(
            nn.Conv2d(4, 64, 3, 1, 1),
            nn.LeakyReLU(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU()
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU()
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.ReLU()
        )
        self.convt6 = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.convt7 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.convt8 = nn.Sequential(
            nn.ConvTranspose2d(1024, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.convt9 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.convt10 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.convt11 = nn.Sequential(
            nn.ConvTranspose2d(128, 3, 3, 1, 1),
            nn.Tanh()
        )
        self._initialize_weights()

    def forward(self, input):
        conv0 = self.conv0(input)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv4 = self.conv4(conv4)
        conv4 = self.conv4(conv4)
        conv5 = self.conv5(conv4)
        convt6 = self.convt6(conv5)
        conv6 = torch.cat((conv4, convt6), 1)
        convt7 = self.convt7(conv6)
        conv6 = torch.cat((conv4, convt7), 1)
        convt7 = self.convt7(conv6)
        conv6 = torch.cat((conv4, convt7), 1)
        convt7 = self.convt7(conv6)
        conv7 = torch.cat((conv3, convt7), 1)
        convt8 = self.convt8(conv7)
        conv8 = torch.cat((conv2, convt8), 1)
        convt9 = self.convt9(conv8)
        conv9 = torch.cat((conv1, convt9), 1)
        convt10 = self.convt10(conv9)
        conv10 = torch.cat((conv0, convt10), 1)
        convt11 = self.convt11(conv10)

        return convt11


    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.normal_(m.weight, mean=0, std=0.02)
                torch.nn.init.constant_(m.bias, 0.1)



class Discriminator_first(nn.Module):
    def __init__(self):
        super(Discriminator_first, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(4, 64, 3, 1, 1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Conv2d(512, 1, 3, 1, 1),
            nn.Sigmoid()
        )
        self._initialize_weights()

    def forward(self, input):
        output = self.feature(input)
        return output

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.normal_(m.weight, mean=0, std=0.02)
                torch.nn.init.constant_(m.bias, 0.1)


class Discriminator_second(nn.Module):
    def __init__(self):
        super(Discriminator_second, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(7, 64, 3, 1, 1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Conv2d(512, 1, 3, 1, 1),
            nn.Sigmoid()
        )
        self._initialize_weights()

    def forward(self, input):
        output = self.feature(input)
        return output

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.normal_(m.weight, mean=0, std=0.02)
                torch.nn.init.constant_(m.bias, 0.1)

Main 

In [0]:
import torch.utils.data as Data
import os

Check whether GPU is enabled

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [0]:
BATCH_SIZE=1
lambda1 = 5
lambda2 = 0.1
lambda3 = 0.1

In [9]:
THIS_FOLDER = os.getcwd()
print(THIS_FOLDER)
!mkdir "/content/models"

/content


In [0]:
def single_gpu_train():
    dataset = shadow_triplets_loader()
    data_loader = Data.DataLoader(dataset, batch_size=BATCH_SIZE)

    G1 = Generator_first().to(device)
    G2 = Generator_second().to(device)
    D1 = Discriminator_first().to(device)
    D2 = Discriminator_second().to(device)


    criterion1 = torch.nn.BCELoss(size_average=False)
    criterion2 = torch.nn.L1Loss()
    optimizerd = torch.optim.Adam([
        {'params': D1.parameters()},
        {'params': D2.parameters()}], lr=0.001)
    optimizerg = torch.optim.Adam([
        {'params': G1.parameters()},
        {'params': G2.parameters()}], lr=0.001)

    for epoch in range(10):
        for i, data in enumerate(data_loader):
            original_image, shadow_mask, shadow_free_image = data
            original_image = original_image.to(device)
            shadow_mask = shadow_mask.to(device)
            shadow_free_image = shadow_free_image.to(device)

            g1_output = G1(original_image)
            g1 = torch.cat((original_image, g1_output), 1)
            gt1 = torch.cat((original_image, shadow_mask), 1)

            prob_gt1 = D1(gt1).detach()
            prob_g1 = D1(g1)

            #D1_loss = -torch.mean(torch.log(prob_gt1) +  torch.log(1 - prob_g1))
            #G1_loss = torch.mean(torch.log(shadow_mask - g1_output))
            D1_loss = criterion1(prob_g1, prob_gt1)
            G1_loss = criterion2(g1_output, shadow_mask)

            g2_input = torch.cat((original_image, shadow_mask), 1)
            g2_output = G2(g2_input)

            gt2 = torch.cat((original_image, shadow_mask, shadow_free_image), 1)
            g2 = torch.cat((original_image, g1_output, g2_output), 1)

            prob_gt2 = D2(gt2).detach()
            prob_g2 = D2(g2)

            #D2_loss = -torch.mean(torch.log(prob_gt2) + torch.log(1 - prob_g2))
            #G2_loss = torch.mean(torch.log(shadow_free_image, g2_output))
            D2_loss = criterion1(prob_g2, prob_gt2)
            G2_loss = criterion2(g2_output, shadow_free_image)

            loss = G1_loss + lambda1 * G2_loss + lambda2 * D1_loss + lambda3 * D2_loss
            print('Epoch: %d | iter: %d | train loss: %.10f' % (epoch, i, float(loss)))
            if epoch % 2000 < 1000:
                optimizerd.zero_grad()
                loss.backward()
                optimizerd.step()
            else:
                optimizerg.zero_grad()
                loss.backward()
                optimizerg.step()
            

        generator1_model = os.path.join(THIS_FOLDER,"model/generator1_%d.pkl" % epoch)
        generator2_model = os.path.join(THIS_FOLDER,"model/generator2_%d.pkl" % epoch)
        discriminator1_model = os.path.join(THIS_FOLDER,"model/discriminator1_%d.pkl" % epoch)
        discriminator2_model = os.path.join(THIS_FOLDER,"model/discriminator2_%d.pkl" % epoch)
        torch.save(G1.state_dict(), generator1_model)
        torch.save(G2.state_dict(), generator2_model)
        torch.save(D1.state_dict(), discriminator1_model)
        torch.save(D2.state_dict(), discriminator2_model)

In [0]:
single_gpu_train()



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: 0 | iter: 93 | train loss: 3637.9711914062
Epoch: 0 | iter: 94 | train loss: 4025.8056640625
Epoch: 0 | iter: 95 | train loss: 3295.4125976562
Epoch: 0 | iter: 96 | train loss: 2807.0349121094
Epoch: 0 | iter: 97 | train loss: 5198.6147460938
Epoch: 0 | iter: 98 | train loss: 2983.0083007812
Epoch: 0 | iter: 99 | train loss: 5476.3896484375
Epoch: 0 | iter: 100 | train loss: 2945.2133789062
Epoch: 0 | iter: 101 | train loss: 3491.6235351562
Epoch: 0 | iter: 102 | train loss: 3640.7768554688
Epoch: 0 | iter: 103 | train loss: 3211.3901367188
Epoch: 0 | iter: 104 | train loss: 2818.6359863281
Epoch: 0 | iter: 105 | train loss: 3562.7590332031
Epoch: 0 | iter: 106 | train loss: 4460.6811523438
Epoch: 0 | iter: 107 | train loss: 3865.8618164062
Epoch: 0 | iter: 108 | train loss: 4514.1591796875
Epoch: 0 | iter: 109 | train loss: 2980.7270507812
Epoch: 0 | iter: 110 | train loss: 3947.5766601562
Epoch: 0 | iter: 111 | t