In [6]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import scipy.ndimage as nd
import sys
import cv2
import clip
import torchvision
import os
from copy import copy,deepcopy
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from matplotlib.lines import Line2D
import jlc
import clip
import torch
import matplotlib.pyplot as plt
from PIL import Image
import sys
# Project root that holds the datasets (d0/, d1/, d2/), the Deliverables/ folder
# and collage_functions.py. The original author's path stays the default; set the
# COLLAGE_ROOT environment variable to run the notebook on another machine.
ROOT = os.environ.get("COLLAGE_ROOT", "C:/Users/Jakob/Desktop/DTU/Computational_Photography/")
if not ROOT.endswith("/"):
    # Every path below is built as ROOT+"relative/path", so ROOT must end with a separator.
    ROOT += "/"
sys.path.append(ROOT)

from collage_functions import (CollageTransformer,save_model,load_model,
                               validate,train,loss_with_ignore_index,
                               plot_loss_dict,CollageMaker,
                               CollageDataset,custom_collate_with_info)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:

class CollageDatasetSaver(torch.utils.data.Dataset):
    """Collage dataset that can additionally export its loaded split to disk.

    The dataset concatenates the CLIP matrices of up to three image datasets
    (index 0: fonts, 1: Simulacra, 2: LAION) into ``self.clip_matrix`` and keeps
    ``self.image_names`` in the same row order. Images are partitioned into
    groups according to four collage types; ``__getitem__`` samples a collage
    from one or more groups plus a set of positive/negative candidates.
    ``save_dataset`` writes the CLIP vectors, names and image files of one
    dataset to the ``Deliverables/test_set_data_compressed/`` folder under ROOT.
    """
    def __init__(self,
                 collage_sizes = list(range(1,8)),
                 candidate_sizes = [8],
                 positive_candidate_prob = 0.5,
                 collage_type_prob = [0.25,0.25,0.25,0.25],
                 normalize = True,
                 padding_idx = -1,
                 clip_dim = 768,
                 small_group_strategy="concat",
                 min_group_length = None,
                 filter_positives = True,
                 split_start_and_stop = [0,1],
                 allow_dataset_mixing = False,
                 CNN = False,
                 cnn_reshape = [224,224,3]):
        """Dataset for training a CollageTransformer

        Args:
            collage_sizes (list, optional): array to sample the collage size from uniformly. Defaults to list(range(1,8)).
            candidate_sizes (list, optional): array to sample the candidate size from uniformly. Defaults to [8].
            positive_candidate_prob (float, optional): Probability of sampling positive candidates for the collage. Defaults to 0.5.
            collage_type_prob (list, optional): length 4 array or list with (unnormalized) probabilities for different collage types in the order ["font","letter","same_prompt","1k"] for datasets [Dafont-free,Dafont-free,Simulacra,LAION-Aesthetics6.5+]. Collage types with probability 0 are not loaded. Defaults to [0.25,0.25,0.25,0.25].
            normalize (bool, optional): Should the clip vectors be L2-normalized. Defaults to True.
            padding_idx (int, optional): padding index used for the torch.tensor entries which are empty since the collage was smaller than the maximum size. Defaults to -1.
            clip_dim (int, optional): Dimension of clip vectors. Defaults to 768.
            small_group_strategy (["concat","ignore"], optional): How to deal with small groups (collages). "concat" combines small groups with other groups. "ignore" only uses single groups - however they can be very small which usually ends in the same images being sampled many times. Defaults to "concat".
            min_group_length (int, optional): Minimum size which groups should have to be included in the dataset. Defaults to None, which means max(collage_sizes)//2.
            filter_positives (bool, optional): Should positives which are already in the collage not be included in candidates. Defaults to True.
            split_start_and_stop (list, optional): dataset starting and stopping point in terms of ratio of indices to use for dataloader. Defaults to [0,1].
            allow_dataset_mixing (bool, optional): should the model show candidates from different datasets than the sampled used for the collage. Defaults to False.
            CNN (bool, optional): Is a CNN being used (instead of CLIP) to embed images to vectors. Defaults to False.
            cnn_reshape (list, optional): Reshape size for inputs of CNN if a CNN is used for embedding vectors. Defaults to [224,224,3].
        """
        assert small_group_strategy in ["concat","ignore"]
        self.CNN = CNN
        self.allow_dataset_mixing = allow_dataset_mixing
        self.split_start_and_stop = split_start_and_stop
        self.small_group_strategy = small_group_strategy
        self.filter_positives = filter_positives
        self.clip_dim = clip_dim
        self.positive_candidate_prob = positive_candidate_prob
        self.normalize = normalize
        self.padding_idx = padding_idx
        self.cnn_reshape = cnn_reshape

        # dataset_types[i] is the dataset index that collage type i draws from
        self.dataset_types = [0,0,1,2]
        self.collage_types = ["font","letter","same_prompt","1k"]

        self.dataset_lengths = [0,0,0]
        self.groups = []
        self.group_types = []
        self.group_lengths = []
        self.image_names = []
        self.collage_type_prob = np.array(collage_type_prob)
        self.collage_type_prob = self.collage_type_prob/(self.collage_type_prob.sum())
        self.collage_sizes = collage_sizes
        max_collage_size = max(collage_sizes)
        self.candidate_sizes = candidate_sizes
        max_candidate_size = max(candidate_sizes)
        self.min_group_length = min_group_length
        if self.min_group_length is None:
            self.min_group_length = max_collage_size//2

        # number of positives the "concat" strategy tries to exceed before it stops merging groups
        self.min_group_size = int(max_collage_size+self.positive_candidate_prob*max_candidate_size)
        self.use_dataset = [False,False,False]
        self.use_groups = []
        self.group_names = []
        self.group_type_lengths = []
        self.clip_matrix = torch.zeros(0,clip_dim)
        self.num_groups = 0
        self.translation = 0

        # Names and row offsets are remembered per dataset so that a second collage
        # type on an already loaded dataset (e.g. "letter" after "font") always
        # groups the names of its own dataset.
        names_by_dataset = {}
        offset_by_dataset = {}
        for i in range(len(self.collage_types)):
            if self.collage_type_prob[i]>0:
                ii = self.dataset_types[i]
                if not self.use_dataset[ii]:
                    # rows of this dataset start after everything loaded so far
                    self.translation = sum(self.dataset_lengths)
                    self.use_dataset[ii] = True
                    clip_matrix, names = self.get_clip_and_names(ii)
                    self.image_names.extend(names)
                    self.dataset_lengths[ii] = len(names)
                    self.clip_matrix = torch.cat((self.clip_matrix,clip_matrix),dim=0)
                    names_by_dataset[ii] = names
                    offset_by_dataset[ii] = self.translation
                groups, group_names = self.get_groups(names_by_dataset[ii],i,translation=offset_by_dataset[ii])
                self.group_names.extend(group_names)
                self.num_groups += len(groups)
                self.groups.extend(groups)
                self.group_types.extend([i for _ in range(len(groups))])
                self.group_lengths.extend([len(g) for g in groups])
                self.group_type_lengths.append(len(groups))
            else:
                self.group_type_lengths.append(0)

        self.group_lengths = np.array(self.group_lengths)
        self.group_types = np.array(self.group_types)

        if normalize:
            self.clip_matrix = torch.nn.functional.normalize(self.clip_matrix,dim=1)

    def save_dataset(self,dataset_idx):
        """Export the loaded split of one dataset to the deliverables folder.

        Saves a dict with the (un-normalized, freshly loaded) CLIP matrix and the
        image names as a .pth file, then re-saves every image of the split into a
        dataset specific sub-folder.

        Args:
            dataset_idx (int): index of the dataset to export (0, 1 or 2)

        Raises:
            ValueError: if dataset_idx is not 0, 1 or 2
        """
        targets = {0: ("Dafonts/","CLIP_fonts.pth"),
                   1: ("Simulacra/","CLIP_sim.pth"),
                   2: ("LAION_aesthetics_6dot5/","CLIP_laion.pth")}
        if dataset_idx not in targets:
            raise ValueError("dataset_idx must be 0, 1 or 2, got "+repr(dataset_idx))
        # index range of this dataset inside self.image_names
        start = sum(self.dataset_lengths[:dataset_idx])
        stop = sum(self.dataset_lengths[:dataset_idx+1])
        clip_matrix, image_names = self.get_clip_and_names(dataset_idx)
        save_dict = {"CLIP_matrix": clip_matrix,
                     "image_names": image_names}
        save0 = ROOT+"Deliverables/test_set_data_compressed/"
        save_path = save0+targets[dataset_idx][0]
        dict_save_path = save0+targets[dataset_idx][1]
        # make sure the output folders exist before anything is written
        os.makedirs(save_path,exist_ok=True)
        torch.save(save_dict,dict_save_path)
        print("saved CLIP for dataset_idx: ",dataset_idx,clip_matrix.shape)
        for idx in tqdm.tqdm(range(start,stop)):
            name = self.image_names[idx]
            # idx lies in [start,stop), so the dataset is fully determined by dataset_idx
            if dataset_idx==0:
                img_path = ROOT+"d0/dafonts-free-v1/fonts_images/"+name+".png"
                img_path_new = save_path+name+".png"
                # font names look like "<font>\<letter>"; one folder per font
                directory = save_path+name[:name.find("\\")]
                os.makedirs(directory,exist_ok=True)
            elif dataset_idx==1:
                # source files are .png but the export is written as .jpg
                img_path = ROOT+"d1/sac/"+name+".png"
                img_path_new = save_path+name+".jpg"
            else:
                img_path = ROOT+"d2/"+name+".jpg"
                # drops everything up to and including the character after the first
                # backslash of the name -- NOTE(review): the +2 is kept from the original, confirm
                img_path_new = save_path+name[name.find("\\")+2:]+".jpg"
            # context manager closes the source file handle after the copy
            with Image.open(img_path) as pil_image:
                pil_image.save(img_path_new)

    def get_clip_and_names(self,dataset_idx):
        """loads CLIP matrix and names which define groups in datasets

        Only the slice given by self.split_start_and_stop (as ratios of the
        dataset length) is returned.

        Args:
            dataset_idx (int): index of the dataset

        Returns:
            clip_matrix,names: clip matrix and names of images as a list in the same order as the clip matrix

        Raises:
            ValueError: if dataset_idx is not 0, 1 or 2
        """
        files = {0: "d0/CLIP_fonts.pth",
                 1: "d1/CLIP_sim.pth",
                 2: "d2/CLIP_laion.pth"}
        if dataset_idx not in files:
            raise ValueError("dataset_idx must be 0, 1 or 2, got "+repr(dataset_idx))
        # NOTE: torch.load unpickles the file; only load files from a trusted source
        loaded = torch.load(ROOT+files[dataset_idx])
        n = len(loaded["CLIP_matrix"])
        start = max(0,int(np.floor(self.split_start_and_stop[0]*n)))
        stop = min(n,int(np.ceil(self.split_start_and_stop[1]*n)))
        clip_matrix = loaded["CLIP_matrix"][start:stop].clone()
        names = loaded["image_names"][start:stop]
        return clip_matrix, names

    def get_groups(self,names,group_idx,translation):
        """returns a set of groups (collages) from a given list of names

        Groups smaller than self.min_group_length are dropped.

        Args:
            names (list): list of image names
            group_idx (int): group (collage type) index: 0 groups by the part before the first backslash, 1 by the part after it, 2 by the part before the first underscore, 3 assigns images round-robin to 10 groups
            translation (int): integer translation in terms of image indices, to make sure they are unique

        Returns:
            groups,group_names: list of groups (lists of image indices) and the name of each associated group

        Raises:
            ValueError: if group_idx is not 0, 1, 2 or 3
        """
        if group_idx==0: #font
            sample_keys = np.array([n.split('\\')[0] for n in names])
        elif group_idx==1: #letter
            sample_keys = np.array([n.split('\\')[1] for n in names])
        elif group_idx==2: #same_prompt
            sample_keys = np.array([n.split('_')[0] for n in names])
        elif group_idx==3: #1k
            # exactly one key per name, cycling 0..9 (the previous tile/slice produced
            # extra keys, and thus out-of-range indices, when len(names)%10 != 0)
            sample_keys = np.arange(len(names))%10
        else:
            raise ValueError("group_idx must be 0, 1, 2 or 3, got "+repr(group_idx))
        uq,uq_inverse = np.unique(sample_keys,return_inverse=True)
        groups = [[] for _ in range(len(uq))]
        for i_sample,i_group in enumerate(np.asarray(uq_inverse).ravel()):
            groups[i_group].append(i_sample+translation)

        keep = [j for j,g in enumerate(groups) if len(g)>=self.min_group_length]
        # groups are ordered like uq, so the name of group j is uq[j] (not the key of sample j)
        group_names = [uq[j] for j in keep]
        groups = [groups[j] for j in keep]
        return groups,group_names

    def get_images(self,image_idx,reshape_size=None,return_torch=False,input_for_reshape=None):
        """return images from a list of image indices

        Args:
            image_idx (Union[int,list]): list of indices to get images for. Negative indices (padding) are skipped.
            reshape_size (list, optional): list,tuple or array [height,width] or [height,width,channels] to resize images into; channels must be 1 or 3 and defaults to 3. Defaults to None.
            return_torch (bool, optional): should a torch tensor be returned. Defaults to False.
            input_for_reshape (list, optional): list of images to reshape instead of loading images from image_idx. Defaults to None.

        Returns:
            images: list of numpy arrays; with return_torch a list of tensors, or a single (N,C,H,W) tensor if reshape_size is given

        Raises:
            IndexError: if an index is not smaller than the total number of loaded images
        """
        if input_for_reshape is None:
            if isinstance(image_idx,int):
                image_idx = [image_idx]
            images = []
            for idx in image_idx:
                if idx<0:
                    continue
                elif 0<=idx and idx<self.dataset_lengths[0]:
                    img_path = ROOT+"d0/dafonts-free-v1/fonts_images/"+self.image_names[idx]+".png"
                elif sum(self.dataset_lengths[:1])<=idx and idx<sum(self.dataset_lengths[:2]):
                    img_path = ROOT+"d1/sac/"+self.image_names[idx]+".png"
                elif sum(self.dataset_lengths[:2])<=idx and idx<sum(self.dataset_lengths[:3]):
                    img_path = ROOT+"d2/"+self.image_names[idx]+".jpg"
                else:
                    # previously this silently re-used the path of the preceding image
                    raise IndexError("image index "+str(int(idx))+" is out of range")
                # np.array forces the pixel data to be read before the file is closed
                with Image.open(img_path) as pil_image:
                    images.append(np.array(pil_image))
        else:
            images = input_for_reshape

        if reshape_size is not None:
            # work on a copy so the caller's list (e.g. self.cnn_reshape) is never modified
            reshape_size = [int(s) for s in reshape_size]
            if len(reshape_size)<3:
                reshape_size = reshape_size+[3]
            assert reshape_size[2] in [1,3]
            # cv2.resize expects dsize=(width,height); the output is then
            # (reshape_size[0],reshape_size[1]) as __getitem__ expects
            dsize = (reshape_size[1],reshape_size[0])
            images = [cv2.resize(im, dsize, interpolation=cv2.INTER_LINEAR) for im in images]
            for i in range(len(images)):
                if len(images[i].shape)==2:
                    # grayscale image: add an explicit channel axis
                    images[i] = images[i][:,:,None]
                n_channels = images[i].shape[2]
                if reshape_size[2]==3:
                    if n_channels==1:
                        images[i] = np.tile(images[i],(1,1,3))
                    elif n_channels==4:
                        # drop the alpha channel (last axis), not image rows
                        images[i] = images[i][:,:,:3]
                    else:
                        assert n_channels==3
                elif reshape_size[2]==1:
                    images[i] = images[i].mean(2,keepdims=True)
        if return_torch and (reshape_size is not None):
            # all images share one size here, so they stack into (N,H,W,C) -> (N,C,H,W)
            images = torch.stack([torch.from_numpy(im) for im in images],dim=0).permute((0,3,1,2))
        elif return_torch and (reshape_size is None):
            images = [torch.from_numpy(im) for im in images]
        return images

    def sample_group_idx(self,group_type):
        """Fast sampling function for getting a group idx

        Args:
            group_type (int): index of group type to sample from

        Returns:
            idx: sampled group idx (index into self.groups)

        Raises:
            ValueError: if group_type is negative
        """
        if group_type<0:
            raise ValueError("group_type must be non-negative, got "+repr(group_type))
        # groups are stored type after type, so a type starts after all groups of earlier types
        offset = sum(self.group_type_lengths[:group_type])
        return np.random.choice(self.group_type_lengths[group_type])+offset

    def sample_negative_idx(self,group_type,group_indices):
        """Sample a negative index from a group type, but without sampling from indices in group_indices

        At most 5 attempts are made; if all of them hit a positive index the last
        sample is returned anyway.

        Args:
            group_type (int): group type index from which to sample from
            group_indices (list): list of group indices which constitute the positive groups, and therefore not sampled

        Returns:
            idx: sampled negative index
        """
        # set for O(1) membership tests
        illegal_indices = {img for g_idx in group_indices for img in self.groups[g_idx]}
        for _ in range(5):
            if self.allow_dataset_mixing:
                idx = np.random.choice(sum(self.dataset_lengths))
            elif self.dataset_types[group_type]==0:
                idx = np.random.choice(self.dataset_lengths[0])
            else:# self.dataset_types[group_type]==1 or 2:
                k = self.dataset_types[group_type]
                idx = np.random.choice(self.dataset_lengths[k])+sum(self.dataset_lengths[:k])
            if idx not in illegal_indices:
                break
        return idx

    def __len__(self):
        """Total number of loaded images over all datasets."""
        return sum(self.dataset_lengths)

    def __getitem__(self, idx):
        """Randomly sample one training example; idx is not used.

        Returns:
            collage, candidate, label, collage_idx, candidate_idx, info: padded
            collage and candidate embeddings (or images if CNN is used), the 0/1
            candidate labels, the image indices of collage and candidates (all
            padded with self.padding_idx) and a dict with "group_type",
            "group_indices" and "group_names".
        """
        max_collage_size = max(self.collage_sizes)
        max_candidate_size = max(self.candidate_sizes)

        n_collage = np.random.choice(self.collage_sizes)
        n_candidates = np.random.choice(self.candidate_sizes)

        info = {}
        group_type = np.random.choice(len(self.collage_types),p=self.collage_type_prob)

        info["group_type"] = [group_type,self.collage_types[group_type]]

        if self.small_group_strategy=="concat":
            # merge randomly drawn groups until there are enough positives
            positives = []
            group_indices = []
            for _ in range(max_collage_size+max_candidate_size):
                group_idx = self.sample_group_idx(group_type)
                group_indices.append(group_idx)
                positives.extend(self.groups[group_idx])

                if len(positives)>self.min_group_size:
                    break
        elif self.small_group_strategy=="ignore":
            group_idx = self.sample_group_idx(group_type)
            positives = self.groups[group_idx]
            group_indices = [group_idx]

        collage_idx = self.padding_idx*torch.ones(max_collage_size,dtype=int)
        collage_idx[:n_collage] = torch.from_numpy(np.random.choice(positives, size=n_collage)).long()

        if len(group_indices)>1:
            # keep only the groups that actually contributed an image to the collage
            in_collage = set(collage_idx[:n_collage].tolist())
            group_indices_actually_used = [g_idx for g_idx in group_indices
                                           if not in_collage.isdisjoint(self.groups[g_idx])]
            positives = [img for g_idx in group_indices_actually_used for img in self.groups[g_idx]]
            group_indices = group_indices_actually_used

        info["group_indices"] = group_indices
        info["group_names"] = [self.group_names[g_idx] for g_idx in group_indices]

        if self.filter_positives:
            # avoid showing collage images as candidates, unless nothing else is left
            collage_values = set(collage_idx.tolist())
            positives_tmp = [n for n in positives if n not in collage_values]
            if len(positives_tmp)>0:
                positives = positives_tmp

        label = self.padding_idx*torch.ones(max_candidate_size,dtype=torch.float32)
        candidate_idx = self.padding_idx*torch.ones(max_candidate_size,dtype=int)
        for i in range(n_candidates):
            if np.random.rand()<self.positive_candidate_prob:
                candidate_idx[i] = np.random.choice(positives)
                label[i] = 1
            else:
                candidate_idx[i] = self.sample_negative_idx(group_type,group_indices)
                label[i] = 0

        if self.CNN:
            candidate = self.padding_idx*torch.ones(max_candidate_size,self.cnn_reshape[2],self.cnn_reshape[0],self.cnn_reshape[1],dtype=torch.float32)
            collage = self.padding_idx*torch.ones(max_collage_size,self.cnn_reshape[2],self.cnn_reshape[0],self.cnn_reshape[1],dtype=torch.float32)
            # get_images skips the negative padding entries, so it returns exactly n images
            candidate[:n_candidates] = self.get_images(candidate_idx,reshape_size=self.cnn_reshape,return_torch=True)/255
            collage[:n_collage] = self.get_images(collage_idx,reshape_size=self.cnn_reshape,return_torch=True)/255
        else:
            candidate = self.padding_idx*torch.ones(max_candidate_size,self.clip_dim,dtype=torch.float32)
            collage = self.padding_idx*torch.ones(max_collage_size,self.clip_dim,dtype=torch.float32)
            candidate[:n_candidates] = self.clip_matrix[candidate_idx[:n_candidates]]
            collage[:n_collage] = self.clip_matrix[collage_idx[:n_collage]]

        return collage, candidate, label, collage_idx, candidate_idx, info


In [34]:
# Load the last 2.5% of every dataset and export the font (0) and Simulacra (1) splits.
dataset = CollageDatasetSaver(split_start_and_stop=[0.975,1.0])
for dataset_idx in (0, 1):
    dataset.save_dataset(dataset_idx)

saved CLIP for dataset_idx:  0 torch.Size([16233, 768])


100%|██████████| 16233/16233 [00:26<00:00, 614.48it/s]


saved CLIP for dataset_idx:  1 torch.Size([5973, 768])


100%|██████████| 5973/5973 [00:18<00:00, 327.51it/s]


In [35]:
# Reload with the last 10% of every dataset and export only the third dataset (index 2).
dataset = CollageDatasetSaver(split_start_and_stop=[0.9,1.0])
dataset.save_dataset(dataset_idx=2)

saved CLIP for dataset_idx:  2 torch.Size([1000, 768])


100%|██████████| 1000/1000 [00:05<00:00, 188.32it/s]


In [20]:
# Build the CollageDataset imported from collage_functions with all probability on the
# fourth collage type (presumably "1k", as in CollageDatasetSaver above -- confirm)
# and print the shape of its CLIP matrix.
dataset = CollageDataset(collage_type_prob=[0,0,0,1])
print(dataset.clip_matrix.shape)

torch.Size([1000, 768])
