In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2
import pytorch_lightning as pl
import os
from torch.utils.data import DataLoader
import sys

# Clear any command line arguments to avoid conflict with Jupyter arguments
sys.argv = sys.argv[:1]

# Now you can safely initialize Hydra
import hydra
from omegaconf import DictConfig


In [2]:
from main import ARLDM, DDPStrategy
from datasets.flintstones import StoryDataset

class CustomStory(StoryDataset):
    def __init__(self, args, prompts, subset='test'):
        super(CustomStory, self).__init__(subset, args=args)
        self.prompts = prompts

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, index):
        # Use zeros as placeholders for images
        images = torch.zeros([5, 3, 256, 256])
        source_images = torch.zeros([5, 3, 224, 224])

        # Use the loaded prompts instead of the text from h5
        texts = self.prompts[index]

        # Tokenize caption using CLIPTokenizer
        tokenized = self.clip_tokenizer(
            texts[1:] if self.args.task == 'continuation' else texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']

        # Tokenize caption using blip tokenizer
        tokenized = self.blip_tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']

        return images, captions, attention_mask, source_images, source_caption, source_attention_mask


  if not hasattr(numpy, tp_name):
  if not hasattr(numpy, tp_name):
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)


In [3]:
class LightningDataset(pl.LightningDataModule):
    def __init__(self, args, prompts=None):
        super(LightningDataset, self).__init__()
        self.kwargs = {"num_workers": args.num_workers, "persistent_workers": True if args.num_workers > 0 else False,
                       "pin_memory": True}
        self.args = args
        self.prompts = prompts

    def setup(self, stage='fit'):
        print(self.prompts)
        self.test_data = CustomStory(self.args, prompts=self.prompts)

    def train_dataloader(self):
        if not hasattr(self, 'trainloader'):
            self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
        return self.trainloader

    def val_dataloader(self):
        if self.val_data is None:
            return None
        return DataLoader(self.val_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)

    def predict_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)

    def get_length_of_train_dataloader(self):
        if not hasattr(self, 'trainloader'):
            self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
        return len(self.trainloader)

In [4]:
prompts = [
    ['Fred and dino are driving a car in a sunny day.', 
     '<char> stopped them on the way.',
     '<char> talked to Fred and dino with angry face because they are driving too fast.',
     'Fred and dino are confused and asked <char> what they should do.',
     '<char> gave them a ticket and told them to report to police station.']
]

def sample_from_prompts(args):
    print(f'loading from checkpoint: {args.test_model_file}')
    assert args.test_model_file is not None, "test_model_file cannot be None"
    dataloader = LightningDataset(args, prompts=prompts)
    dataloader.setup()

    model = ARLDM.load_from_checkpoint(args.test_model_file, args=args, strict=False)

    predictor = pl.Trainer(
        accelerator='gpu',  # 'gpu' is still valid for single GPU
        devices=1,  # Specify 1 GPU, you can also pass an index e.g., [0]
        max_epochs=-1,  # This setting might not be valid for single GPU; set an actual number of epochs
        benchmark=True,  # This can be set to False since benchmarking is more useful for multi-GPU setups
        precision=16  # FP16 precision can still be used for single GPU
    )

    predictions = predictor.predict(model, dataloader)
    images = [elem for sublist in predictions for elem in sublist[0]]

    return images, predictions


In [13]:
results = None

@hydra.main(config_path="/media/mldadmin/home/s123mdg35_05/ar-ldm/", config_name="config")
def main(args):
    global results
    pl.seed_everything(args.seed)
    results = sample_from_prompts(args)
    

In [14]:
if __name__ == "__main__":
    # reset the working directory
    os.chdir(os.path.join(os.getcwd(), '/media/mldadmin/home/s123mdg35_05/ar-ldm/'))
    main()

Global seed set to 0


loading from checkpoint: /media/mldadmin/home/s123mdg35_05/ar-ldm/ckpts/flintstones_ada_pinkhat/epoch=99-step=1300.ckpt
[['Fred and dino are driving a car in a sunny day.', '<char> stopped them on the way.', '<char> talked to Fred and dino with angry face because they are driving too fast.', 'Fred and dino are confused and asked <char> what they should do.', '<char> gave them a ticket and told them to report to police station.']]
clip 4 new tokens added
blip 1 new tokens added
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth


Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


[['Fred and dino are driving a car in a sunny day.', '<char> stopped them on the way.', '<char> talked to Fred and dino with angry face because they are driving too fast.', 'Fred and dino are confused and asked <char> what they should do.', '<char> gave them a ticket and told them to report to police station.']]
clip 4 new tokens added
blip 1 new tokens added


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting: 0it [00:00, ?it/s]

In [16]:
images, predictions = results

[<PIL.Image.Image image mode=RGB size=256x256 at 0x7F6BA5729B20>,
 <PIL.Image.Image image mode=RGB size=256x256 at 0x7F6BA57295E0>,
 <PIL.Image.Image image mode=RGB size=256x256 at 0x7F6BA5729C40>,
 <PIL.Image.Image image mode=RGB size=256x256 at 0x7F6BA5729BB0>,
 <PIL.Image.Image image mode=RGB size=256x256 at 0x7F6BA57290D0>]

In [None]:
# display PIL images
def display_images(images):
    fig = plt.figure(figsize=(20, 20))
    columns = 5
    rows = 1
    for i in range(1, columns * rows + 1):
        img = images[i-1]
        fig.add_subplot(rows, columns, i)
        plt.imshow(img)
    plt.show()
    
display_images(images)