In [None]:
from google.colab import drive

# Attach Google Drive to the Colab filesystem; the dataset archive and the
# model checkpoints used further down are read from / written to this mount.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
!mkdir -p /content/drive/MyDrive/ARShadowGAN_samples
!mkdir -p /content/drive/MyDrive/ARShadowGAN_checkpoints


In [None]:
# Path to your zipped dataset in Google Drive
zip_file_path = '/content/drive/MyDrive/colab_data/shadow_ar_dataset.zip' # Adjust if your path is different

# Destination directory within Google Drive to unzip to
# This will create a new folder (e.g., 'shadow_ar_dataset') inside your colab_data folder
destination_path = '/content/colab_data/shadow_ar_dataset'

# Create the destination directory if it doesn't exist (optional, unzip usually creates it)
import os
os.makedirs(destination_path, exist_ok=True)

# Unzip the file
!unzip -q {zip_file_path} -d {destination_path}
print(f"Dataset unzipped to: {destination_path}")

Dataset unzipped to: /content/colab_data/shadow_ar_dataset


In [None]:
import os

base_path = "/content/colab_data/shadow_ar_dataset/dataset/train"
expected_folders = ("mask", "noshadow", "robject", "rshadow", "shadow")

# Sanity check: report how many entries each sub-folder of the training split
# contains (all five are expected to hold the same number of files).
for folder in expected_folders:
    folder_path = os.path.join(base_path, folder)
    entry_count = len(os.listdir(folder_path))
    print(f"{folder}: {entry_count} files")

mask: 2552 files
noshadow: 2552 files
robject: 2552 files
rshadow: 2552 files
shadow: 2552 files


In [None]:
pip install numpy==1.24.4 --force-reinstall


Collecting numpy==1.24.4
  Downloading numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Downloading numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m84.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
blosc2 3.6.1 requires numpy>=1.26, but you have numpy 1.24.4 which is incompatible.
xarray-einstats 0.9.1 requires numpy>=1.25, but you have numpy 1.24.4 which is incompatible.
jaxlib 0.5.1 requires numpy>=1.25, but you have numpy 1.24.4 which is incompatible.
pymc 5.25.1 requires numpy

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F

# ==========================
# Dataset Loader
# ==========================
class ShadowDataset(Dataset):
    """Paired image dataset for shadow generation.

    ``root_dir`` must contain the sub-folders ``noshadow``, ``mask``,
    ``robject``, ``rshadow`` and ``shadow``. One sample is made of the five
    files that share the same filename across those folders; the list of
    filenames is taken from ``noshadow``.

    Args:
        root_dir: Directory that holds the five sub-folders.
        transform: Callable applied to every loaded PIL image. It must return
            a tensor, because the results are joined with ``torch.cat``.
            When ``None``, ``transforms.ToTensor()`` is used.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.noshadow_dir = os.path.join(root_dir, 'noshadow')
        self.mask_dir = os.path.join(root_dir, 'mask')
        self.robject_dir = os.path.join(root_dir, 'robject')
        self.rshadow_dir = os.path.join(root_dir, 'rshadow')
        self.shadow_dir = os.path.join(root_dir, 'shadow')
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.filenames = sorted(os.listdir(self.noshadow_dir))

    def __len__(self):
        """Return the number of samples (files found in ``noshadow``)."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A tuple ``(inputs, shadow)`` where ``inputs`` is the channel-wise
            concatenation of noshadow (RGB), mask (grayscale), robject (RGB)
            and rshadow (RGB), and ``shadow`` is the RGB target image.
        """
        fname = self.filenames[idx]

        # The default transform=None used to fail with "'NoneType' object is
        # not callable"; fall back to a plain PIL -> tensor conversion.
        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()

        def load_img(folder, mode='RGB'):
            path = os.path.join(folder, fname)
            # Context manager closes the underlying file handle deterministically.
            with Image.open(path) as image:
                return to_tensor(image.convert(mode))

        noshadow = load_img(self.noshadow_dir, 'RGB')
        mask = load_img(self.mask_dir, 'L')
        robject = load_img(self.robject_dir, 'RGB')
        rshadow = load_img(self.rshadow_dir, 'RGB')
        shadow = load_img(self.shadow_dir, 'RGB')

        # Concatenate: 3 + 1 + 3 + 3 = 10 channels
        return torch.cat([noshadow, mask, robject, rshadow], dim=0), shadow

# ==========================
# UNet Generator
# ==========================
class UNetGenerator(nn.Module):
    """U-Net style encoder/decoder with skip connections.

    Four stride-2 encoder stages and a stride-2 bottleneck halve the spatial
    size five times; four decoder stages plus a final transposed convolution
    double it back. Every decoder stage after the first receives the matching
    encoder output concatenated on the channel axis. The output passes through
    ``Tanh``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the generated image.
        features: Width of the first encoder stage; deeper stages use
            multiples of it.
    """

    def __init__(self, in_channels=10, out_channels=3, features=64):
        super().__init__()

        def encoder_stage(c_in, c_out):
            # Stride-2 convolution halves height and width.
            layers = [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
            return nn.Sequential(*layers)

        def decoder_stage(c_in, c_out):
            # Stride-2 transposed convolution doubles height and width.
            layers = [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
            return nn.Sequential(*layers)

        f1 = features
        f2 = features * 2
        f4 = features * 4
        f8 = features * 8

        # Encoder
        self.down1 = encoder_stage(in_channels, f1)
        self.down2 = encoder_stage(f1, f2)
        self.down3 = encoder_stage(f2, f4)
        self.down4 = encoder_stage(f4, f8)

        # Bottleneck (no normalisation layer)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f8, f8, 4, stride=2, padding=1),
            nn.ReLU(),
        )

        # Decoder: inputs of up2..up4 and final are doubled by the skip tensors
        self.up1 = decoder_stage(f8, f8)
        self.up2 = decoder_stage(f8 * 2, f4)
        self.up3 = decoder_stage(f4 * 2, f2)
        self.up4 = decoder_stage(f2 * 2, f1)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(f1 * 2, out_channels, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the encoder, bottleneck and skip-connected decoder on ``x``."""
        skip1 = self.down1(x)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        skip4 = self.down4(skip3)

        y = self.bottleneck(skip4)

        y = self.up1(y)
        y = self.up2(torch.cat([y, skip4], dim=1))
        y = self.up3(torch.cat([y, skip3], dim=1))
        y = self.up4(torch.cat([y, skip2], dim=1))
        y = torch.cat([y, skip1], dim=1)
        return self.final(y)

# ==========================
# Discriminator
# ==========================
class Discriminator(nn.Module):
    """Per-pixel conditional discriminator.

    Three stride-1 3x3 convolutions keep the spatial size of the input, so the
    output is one probability (after ``Sigmoid``) per pixel.

    Args:
        in_channels: Total channels of the conditioning input plus the image
            being judged. The default 13 matches the 10-channel generator
            input concatenated with a 3-channel shadow image.
    """

    def __init__(self, in_channels=13):
        super().__init__()
        # Layer order and indices are kept as-is so existing state_dicts
        # (keys "model.0.*", "model.2.*", "model.4.*") still load.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, 3, 1, 1), nn.Sigmoid()
        )

    def forward(self, x, shadow):
        """Score ``shadow`` conditioned on ``x``.

        ``x`` and ``shadow`` are concatenated on the channel axis (dim 1), so
        their channel counts must add up to ``in_channels``.

        Returns:
            Tensor with one channel and the same batch/spatial size as the
            inputs, with values in [0, 1].
        """
        return self.model(torch.cat([x, shadow], dim=1))

# ==========================
# Training Function
# ==========================
def train(
    start_epoch=581,
    num_epochs=600,
    data_root="/content/colab_data/shadow_ar_dataset/dataset/train",
    checkpoint_dir="/content/drive/MyDrive",
    save_img_dir="/content/drive/MyDrive/shadow_outputs",
    batch_size=8,
):
    """Train (or resume training of) the generator/discriminator pair.

    Args:
        start_epoch: Number of epochs already completed. Checkpoints named
            ``generator_unet_epoch_{start_epoch}.pth`` and
            ``discriminator_epoch_{start_epoch}.pth`` are loaded from
            ``checkpoint_dir`` when they exist.
        num_epochs: Total epoch count to reach; training runs for
            ``range(start_epoch, num_epochs)``.
        data_root: Directory passed to ``ShadowDataset``.
        checkpoint_dir: Directory checkpoints are read from and written to.
        save_img_dir: Directory where one grid of generated and one grid of
            real images is written after every epoch.
        batch_size: Mini-batch size for the ``DataLoader``.

    Only the model weights are checkpointed; the Adam optimizer state is not,
    so its moment estimates start from zero on every resume.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): ToTensor yields values in [0, 1] while UNetGenerator ends
    # in Tanh ([-1, 1]). Left unchanged so existing checkpoints keep their
    # meaning - confirm whether a Normalize step was intended.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    dataset = ShadowDataset(data_root, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    generator = UNetGenerator().to(device)
    discriminator = Discriminator().to(device)

    def checkpoint_paths(epoch):
        """Return (generator_path, discriminator_path) for ``epoch``."""
        return (
            os.path.join(checkpoint_dir, f"generator_unet_epoch_{epoch}.pth"),
            os.path.join(checkpoint_dir, f"discriminator_epoch_{epoch}.pth"),
        )

    g_path, d_path = checkpoint_paths(start_epoch)

    # map_location lets a checkpoint saved on a GPU runtime load on a CPU-only
    # runtime (and the other way round) instead of raising.
    if os.path.exists(g_path):
        generator.load_state_dict(torch.load(g_path, map_location=device))
        print(f"✅ Loaded generator from {g_path}")
    elif start_epoch > 0:
        # Make it visible that the "resumed" run actually starts from scratch.
        print(f"⚠️ Generator checkpoint not found at {g_path}; using random weights")
    if os.path.exists(d_path):
        discriminator.load_state_dict(torch.load(d_path, map_location=device))
        print(f"✅ Loaded discriminator from {d_path}")
    elif start_epoch > 0:
        print(f"⚠️ Discriminator checkpoint not found at {d_path}; using random weights")

    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    adversarial_loss = nn.BCELoss()
    pixelwise_loss = nn.L1Loss()

    os.makedirs(save_img_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(start_epoch, num_epochs):
        g_total, d_total = 0.0, 0.0
        # Stay None when the loader yields nothing, so the sample-saving step
        # below can be skipped instead of failing on an undefined name.
        fake_shadow, real_shadow = None, None
        loop = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]")

        for i, (inputs, real_shadow) in enumerate(loop):
            inputs, real_shadow = inputs.to(device), real_shadow.to(device)

            # One generator forward pass serves both updates: the detached
            # tensor feeds the discriminator, the attached one the generator.
            fake_shadow = generator(inputs)

            # === Train Discriminator ===
            optimizer_D.zero_grad()
            d_real = discriminator(inputs, real_shadow)
            d_fake = discriminator(inputs, fake_shadow.detach())

            # Targets follow the discriminator's output shape rather than a
            # hard-coded (N, 1, 256, 256).
            real_label = torch.ones_like(d_real)
            fake_label = torch.zeros_like(d_fake)

            d_loss = 0.5 * (adversarial_loss(d_real, real_label) + adversarial_loss(d_fake, fake_label))
            d_loss.backward()
            optimizer_D.step()

            # === Train Generator ===
            optimizer_G.zero_grad()
            g_adv = adversarial_loss(discriminator(inputs, fake_shadow), real_label)
            g_l1 = pixelwise_loss(fake_shadow, real_shadow)
            g_loss = g_adv + 100 * g_l1  # Weighted sum
            g_loss.backward()
            optimizer_G.step()

            g_total += g_loss.item()
            d_total += d_loss.item()
            loop.set_postfix(G_Loss=g_total / (i+1), D_Loss=d_total / (i+1))

        # Save model
        g_out, d_out = checkpoint_paths(epoch + 1)
        torch.save(generator.state_dict(), g_out)
        torch.save(discriminator.state_dict(), d_out)

        # Save sample images (taken from the last batch of the epoch)
        if fake_shadow is not None:
            save_image(fake_shadow.detach()[:4], f"{save_img_dir}/fake_epoch_{epoch+1}.jpg", normalize=True)
            save_image(real_shadow[:4], f"{save_img_dir}/real_epoch_{epoch+1}.jpg", normalize=True)

# Start training only when this code runs as the top-level module, which is
# the case when the cell is executed in the notebook.
if __name__ == "__main__":
    train()


✅ Loaded generator from /content/drive/MyDrive/generator_unet_epoch_581.pth
✅ Loaded discriminator from /content/drive/MyDrive/discriminator_epoch_581.pth


Epoch [582/600]: 100%|██████████| 319/319 [02:02<00:00,  2.60it/s, D_Loss=0.334, G_Loss=5.22]
Epoch [583/600]: 100%|██████████| 319/319 [02:04<00:00,  2.56it/s, D_Loss=0.324, G_Loss=5.26]
Epoch [584/600]: 100%|██████████| 319/319 [02:04<00:00,  2.57it/s, D_Loss=0.32, G_Loss=5.37]
Epoch [585/600]: 100%|██████████| 319/319 [02:04<00:00,  2.56it/s, D_Loss=0.305, G_Loss=5.43]
Epoch [586/600]: 100%|██████████| 319/319 [02:03<00:00,  2.58it/s, D_Loss=0.32, G_Loss=5.28]
Epoch [587/600]: 100%|██████████| 319/319 [02:02<00:00,  2.60it/s, D_Loss=0.332, G_Loss=5.23]
Epoch [588/600]: 100%|██████████| 319/319 [02:03<00:00,  2.58it/s, D_Loss=0.316, G_Loss=5.37]
Epoch [589/600]: 100%|██████████| 319/319 [02:04<00:00,  2.56it/s, D_Loss=0.306, G_Loss=5.43]
Epoch [590/600]: 100%|██████████| 319/319 [02:03<00:00,  2.59it/s, D_Loss=0.306, G_Loss=5.3]
Epoch [591/600]: 100%|██████████| 319/319 [02:03<00:00,  2.59it/s, D_Loss=0.365, G_Loss=5.04]
Epoch [592/600]: 100%|██████████| 319/319 [02:02<00:00,  2.60it