In [22]:
from pathlib import Path

import torch
from torchvision import transforms
from torchvision.utils import save_image

# Predictor

Here, we instantiate a standard ImageNet-pretrained image model (ResNet-18), fine-tune it on CelebAMask-HQ, and use it as our "oracle"
for rejection sampling when generating the CausalCelebA dataset.

In [None]:
class PredictorPipeline(pl.LightningModule):
    """Fine-tunable ResNet-18 classifier used as an "oracle" predictor.

    An ImageNet-pretrained ResNet-18 has its final fully connected layer
    replaced by a linear head with ``2 ** len(cared_list)`` outputs and is
    trained with cross-entropy on CelebAMask-HQ via ``CelebAMaskHQDataset``.

    Parameters
    ----------
    cared_list : list of str, optional
        Attribute selection forwarded to ``CelebAMaskHQDataset``. The output
        layer has ``2 ** len(cared_list)`` classes (presumably one per joint
        binary configuration of the listed attributes -- confirm against the
        dataset's label encoding). ``None`` (the default) means ``["all"]``.
    dat_root : str, optional
        Dataset root directory; falsy values fall back to
        ``"dat/CelebAMask-HQ"``.
    lr : float
        Learning rate used by the SGD optimizer.
    batch_size : int
        Batch size for both the training and validation loaders.
    """

    # Per-channel mean/std of ImageNet, matching the pretrained ResNet weights.
    _IMAGENET_MEAN = [0.485, 0.456, 0.406]
    _IMAGENET_STD = [0.229, 0.224, 0.225]
    _IMAGE_SIZE = (256, 256)
    _DEFAULT_DAT_ROOT = "dat/CelebAMask-HQ"

    def __init__(
        self,
        cared_list=None,
        dat_root=None,
        lr=0.01,
        batch_size=100,
    ):
        super().__init__()
        self.lr = lr
        self.batch_size = batch_size
        # A None default avoids sharing one mutable list between instances.
        self.cared_list = ["all"] if cared_list is None else cared_list

        model = torchvision.models.resnet18(pretrained=True)
        # Replace the ImageNet head with one output per class of this task.
        model.fc = nn.Linear(model.fc.in_features, 2 ** len(self.cared_list))
        self.model = model

        self.dat_root = dat_root or self._DEFAULT_DAT_ROOT

    def _build_loader(self, env, transform, shuffle):
        """Return a DataLoader over the ``env`` split of CelebAMask-HQ.

        ``drop_last=True`` guarantees every batch holds exactly
        ``self.batch_size`` samples, which ``validation_epoch_end`` relies on
        when turning per-batch correct counts into an accuracy.
        """
        dataset = CelebAMaskHQDataset(
            root=self.dat_root,
            norm=True,
            transform=transform,
            env=env,
            cared_list=self.cared_list,
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle, drop_last=True
        )

    def train_dataloader(self):
        """Training loader: resize, random horizontal flip, ImageNet normalization."""
        transforms_train = transforms.Compose(
            [
                transforms.Resize(self._IMAGE_SIZE),
                transforms.RandomHorizontalFlip(),  # data augmentation
                transforms.ToTensor(),
                transforms.Normalize(self._IMAGENET_MEAN, self._IMAGENET_STD),
            ]
        )
        return self._build_loader("train", transforms_train, shuffle=True)

    def val_dataloader(self):
        """Validation loader over the ``"test"`` split, without augmentation.

        Not shuffled: with ``drop_last=True`` a shuffled loader would drop a
        different random subset every epoch, making the monitored ``val_loss``
        (which drives the LR scheduler) noisy across epochs.
        """
        transforms_test = transforms.Compose(
            [
                transforms.Resize(self._IMAGE_SIZE),
                transforms.ToTensor(),
                transforms.Normalize(self._IMAGENET_MEAN, self._IMAGENET_STD),
            ]
        )
        return self._build_loader("test", transforms_test, shuffle=False)

    def configure_optimizers(self):
        """SGD with momentum, plus ReduceLROnPlateau monitoring ``val_loss``."""
        # Use the configured learning rate (it was previously hard-coded to 0.01).
        optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
        lr_scheduler = ReduceLROnPlateau(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def _forward_and_loss(self, batch):
        """Run the model on an ``(inputs, targets)`` batch.

        Returns ``(logits, targets, loss)`` where ``loss`` is the mean
        cross-entropy between ``logits`` and ``targets``.
        """
        x, y = batch
        logits = self.model(x)
        loss = nn.functional.cross_entropy(logits, y)
        return logits, y, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the epoch-averaged training cross-entropy loss."""
        _, _, loss = self._forward_and_loss(batch)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and return the batch's number of correct predictions."""
        logits, y, loss = self._forward_and_loss(batch)
        preds = logits.argmax(dim=1)
        corrects = (preds == y).sum()

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        return corrects

    def validation_epoch_end(self, outputs):
        """Print the epoch's validation accuracy from per-batch correct counts."""
        if len(outputs) == 0:
            # No full validation batch was produced; accuracy is undefined.
            return
        # torch.stack keeps the counts on their device and avoids re-wrapping tensors.
        total_corrects = torch.stack(list(outputs)).sum().float()
        accuracy = total_corrects / (self.batch_size * len(outputs))
        print(f"acc at epoch {self.current_epoch} is {accuracy.item()}")

In [10]:
# Fetch the pretrained progressive-growing GAN ("PGAN", checkpoint
# "celebAHQ-512") from the PyTorch GAN Zoo through torch.hub. It is used to
# sample synthetic face images; useGPU=False keeps the generator on the CPU.
gen_model = torch.hub.load(
    "facebookresearch/pytorch_GAN_zoo:hub",
    "PGAN",
    pretrained=True,
    model_name="celebAHQ-512",
    useGPU=False,
)

Downloading: "https://github.com/facebookresearch/pytorch_GAN_zoo/zipball/hub" to /Users/adam2392/.cache/torch/hub/hub.zip
Downloading: "https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ16_december_s7_i96000-9c72988c.pth" to /Users/adam2392/.cache/torch/hub/checkpoints/celebaHQ16_december_s7_i96000-9c72988c.pth
100%|████████████████████████████████████████████████████████████████████████| 264M/264M [00:16<00:00, 16.8MB/s]


Average network found !


In [20]:
# Sample images from the pretrained PGAN, downscale them, and save them as PNGs.
# NOTE(review): output still lands in the user's Downloads folder; consider
# moving it to a configurable data directory in the notebook's config cell.
image_save_path = Path.home() / "Downloads" / "Pretrain-Gen"
image_save_path.mkdir(parents=True, exist_ok=True)
num_images = 10
image_size = 64  # side length (pixels) of the saved images
seed = 0

# Seed torch's RNG so the sampled noise (and hence the images) is reproducible;
# assumes buildNoiseData draws from torch's global RNG -- confirm.
torch.manual_seed(seed)
noise, _ = gen_model.buildNoiseData(num_images)
with torch.no_grad():
    generated_images = gen_model.test(noise)

# Resize the whole batch at once, then map the generator output from [-1, 1]
# to [0, 1] for save_image (assumes a [-1, 1] output range -- confirm).
# Previously the resized image was computed but the full-resolution one was saved.
resize_transform = transforms.Resize((image_size, image_size))
images_to_save = (resize_transform(generated_images).clamp(min=-1, max=1) + 1) / 2

for i, image in enumerate(images_to_save, start=1):
    save_image(image, image_save_path / f"{i}.png")