In [32]:
import torch
import os
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage as nd
import tqdm

assert torch.cuda.is_available()

In [112]:
def get_dummy_batch(im_size=256, batch_size=32, feature_dim1=76, feature_dim2=32,num_classes=3,
                    noisyness_image=0.95,
                    noisy_features=0.5):
    
    ground_truth = np.random.rand(batch_size, num_classes, im_size, im_size)
    #blur ground truth and make into class pixel map:
    sigma = im_size/10
    ground_truth = nd.gaussian_filter(ground_truth, sigma=(0,0,sigma,sigma))
    ground_truth[:,0] += ground_truth.std()*1.0
    ground_truth = np.argmax(ground_truth, axis=1)[:,None]
    image = np.random.rand(batch_size, 1, im_size, im_size)
    image = noisyness_image*image + (1-noisyness_image)*ground_truth/num_classes
    features = np.random.rand(batch_size, 1, feature_dim1, feature_dim2)
    ground_truth_resized = torch.nn.functional.interpolate(torch.tensor(ground_truth).float(), 
                                                           size=(feature_dim1, feature_dim2), 
                                                           mode="nearest").numpy()
    features = noisy_features*features + (1-noisy_features)*ground_truth_resized/num_classes
    return features, image, ground_truth

num_dummy_datapoints = 32
batch = get_dummy_batch(batch_size=num_dummy_datapoints)

#save data in ./data/features, ./data/images, ./data/ground_truth

matplotlib_palette = [0,0,0]+sum([[int(round(c2*255)) for c2 in c] for c in plt.get_cmap("tab20").colors][::2],[])
folder_names = ["features", "images", "ground_truth"]
for i in range(num_dummy_datapoints):
    for j in range(len(batch)):
        save_path = os.path.join("./data", folder_names[j], f"{i:06d}.png")
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        mult = 255 if j < 2 else 1
        img = Image.fromarray((batch[j][i][0]*mult).astype(np.uint8))
        #put pallete if ground truth
        if j == 2:
            img = img.convert("P", colors=3)
            img.putpalette(matplotlib_palette)
        img.save(save_path)
    

In [56]:
#create dataloader for the structure of the data
import torch
from torchvision import transforms

class Dataset(torch.utils.data.Dataset):
    def __init__(self, root_dir="./data", im_reshape=(64,64)):
        self.root_dir = root_dir
        self.im_reshape = im_reshape
        self.folder_names = ["features", "images", "ground_truth"]
        self.filenames = []
        for path in list((Path(root_dir)/ self.folder_names[0]).glob("*")):
            self.filenames.append(path.name)
        
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        filename = self.filenames[idx]
        for folder_name in self.folder_names:
            path = Path(self.root_dir)/folder_name/filename
            if folder_name == self.folder_names[0]:
                #load features
                features = torch.from_numpy(np.array(Image.open(path))/255).float().unsqueeze(0)
            elif folder_name == self.folder_names[1]:
                #load image
                image = torch.from_numpy(np.array(Image.open(path))/255).float().unsqueeze(0).unsqueeze(0)
            elif folder_name == self.folder_names[2]:
                #load ground truth
                ground_truth = torch.from_numpy(np.array(Image.open(path))).float().unsqueeze(0).unsqueeze(0)
        image = torch.nn.functional.interpolate(image, size=self.im_reshape, mode="area").squeeze(0)
        ground_truth = torch.nn.functional.interpolate(ground_truth, size=self.im_reshape, mode="area").long().squeeze(0)
        sample = [features,image,ground_truth]
        return sample

dataset = Dataset()
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(0.8*len(dataset)), len(dataset)-int(0.8*len(dataset))])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True)

batch = next(iter(train_dataloader))

print("features shape:", batch[0].shape)
print("image shape:", batch[1].shape)
print("ground truth shape:", batch[2].shape)


features shape: torch.Size([4, 1, 76, 32])
image shape: torch.Size([4, 1, 64, 64])
ground truth shape: torch.Size([4, 1, 64, 64])


In [18]:
from diffusers import UNet2DConditionModel
import torch

unet = UNet2DConditionModel(block_out_channels=[32,32,32,32],encoder_hid_dim=64,cross_attention_dim = 64, in_channels=1, out_channels=1)
def number_of_parameters(model):
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"model has {num_params} trainable parameters")

number_of_parameters(unet)
unet.to("cuda")
hidden_states = torch.randn(4, 3, 64).cuda()
timesteps = torch.tensor(0).long().cuda()
image = torch.randn(4, 1, 128, 64).cuda()
print("image shape:", image.shape)
out = unet(image, timesteps, hidden_states)
print("out shape:", out["sample"].shape)

model has 1121889 trainable parameters
image shape: torch.Size([4, 1, 128, 64])
out shape: torch.Size([4, 1, 128, 64])


In [24]:
def convert_unet_to_downnet(unet):
    unet.up_blocks = torch.nn.ModuleList()
    unet.conv_norm_out = None
    unet.conv_out = torch.nn.Identity()
    return unet

unet = UNet2DConditionModel(block_out_channels=[32,32,32,32],encoder_hid_dim=64,cross_attention_dim = 64, in_channels=1, out_channels=1)
def number_of_parameters(model):
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"model has {num_params} trainable parameters")

number_of_parameters(unet)
unet = convert_unet_to_downnet(unet)
number_of_parameters(unet)
unet.to("cuda")
hidden_states = torch.randn(4, 3, 64).cuda()
timesteps = torch.tensor(0).long().cuda()
image = torch.randn(4, 1, 128, 64).cuda()
print("image shape:", image.shape)
out = unet(image, timesteps, hidden_states)
print("out shape:", out["sample"].shape)

model has 1121889 trainable parameters
model has 457216 trainable parameters
image shape: torch.Size([4, 1, 128, 64])
out shape: torch.Size([4, 32, 16, 8])


In [71]:
from argparse import Namespace

def get_default_args():
    args = Namespace()
    #training args
    args.batch_size = 4
    args.lr = 0.001
    args.num_epochs = 10
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #model args
    args.im_size=64
    args.feature_dim1=76
    args.feature_dim2=32
    args.block_out_channels = [32,32,64,64,128] #how large the matrices are in the unet
    args.num_classes = 3 # number of blobs+ background class

    #args you probably dont want to change
    args.encoder_hid_dim = 64
    args.cross_attention_dim = 64
    args.down_block_types = ["DownBlock2D", "DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"]
    args.up_block_types = ["UpBlock2D", "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
    args.mid_block_type = "UNetMidBlock2DCrossAttn"
    # One of ["concat","embed","simple_x_attn","x_attn"] where each is more complex than the one before. Not sure what is best
    args.conditioning_mode = "x_attn"
    return args

def modify_block_names(block_names,conditioning_mode):
    if conditioning_mode in ["concat","embed"]:
        block_names = [x.replace("SimpleCrossAttn","") for x in block_names]
        block_names = [x.replace("CrossAttn","") for x in block_names]
    elif conditioning_mode == "simple_x_attn":
        block_names = [x.replace("CrossAttn","SimpleCrossAttn") for x in block_names]
    else:
        assert conditioning_mode == "x_attn"
    return block_names

def is_power_of_2(num):
    power = np.log2(num)
    is_power = np.isclose(power,np.round(power))
    return np.ceil(power).astype(int),is_power

class ToreNet(torch.nn.Module):
    def __init__(self,args):
        super().__init__()
        self.args = args

        down_block_types = args.down_block_types
        while len(args.block_out_channels) != len(down_block_types):
            down_block_types = [args.down_block_types[0]] + args.down_block_types
        up_block_types = args.up_block_types
        while len(args.block_out_channels) != len(up_block_types):
            up_block_types = [args.up_block_types[0]] + args.up_block_types
        if args.conditioning_mode in ["concat","embed"]:
            mid_block_type = None
        elif args.conditioning_mode == "simple_x_attn":
            mid_block_type = args.mid_block_type.replace("CrossAttn","SimpleCrossAttn")
        elif args.conditioning_mode == "x_attn":
            mid_block_type = args.mid_block_type
        down_block_types = modify_block_names(down_block_types,args.conditioning_mode)
        up_block_types = modify_block_names(up_block_types,args.conditioning_mode)
        self.image_unet = UNet2DConditionModel(block_out_channels=args.block_out_channels,
                                         encoder_hid_dim=args.encoder_hid_dim,
                                         cross_attention_dim=args.cross_attention_dim,
                                         down_block_types=down_block_types,
                                         up_block_types=up_block_types,
                                         mid_block_type=mid_block_type,
                                         in_channels=2 if args.conditioning_mode == "concat" else 1, 
                                         out_channels=args.num_classes,
                                         addition_embed_type="text" if args.conditioning_mode == "embed" else None)
        assert args.im_size % 2**(len(args.block_out_channels)-1) == 0, "image size must be divisible by 2**num_down_blocks"
        if args.conditioning_mode in ["embed","simple_x_attn","x_attn"]:
            down_block_types = modify_block_names(down_block_types,"embed")
            up_block_types = modify_block_names(up_block_types,"embed")
            self.feature_downnet = UNet2DConditionModel(block_out_channels=args.block_out_channels[:-1]+[args.encoder_hid_dim],
                                         down_block_types=down_block_types,
                                         up_block_types=up_block_types,
                                         mid_block_type=None,
                                         in_channels=2 if args.conditioning_mode == "concat" else 1, 
                                         out_channels=args.num_classes)
            self.feature_unet = convert_unet_to_downnet(self.feature_downnet)
            
            #add so we can accept strange feature sizes
            power_ceil,is_power = is_power_of_2(args.feature_dim1)
            self.feature_dim1_linear = torch.nn.Identity() if is_power else torch.nn.Linear(args.feature_dim1, power_ceil)
            power_ceil,is_power = is_power_of_2(args.feature_dim2)
            self.feature_dim2_linear = torch.nn.Identity() if is_power else torch.nn.Linear(args.feature_dim2, power_ceil)

    def forward(self, image, features):
        dummy_timesteps = torch.tensor(0).long().cuda() #added simply because this library is made for diffusion models and we dont want to rewrite their code
        if self.args.conditioning_mode == "concat":
            features = torch.nn.functional.interpolate(features, size=(image.shape[2],image.shape[3]), mode="area")
            x = torch.concat([image, features], dim=1)
            x = self.image_unet(x, dummy_timesteps, hidden_states)
        elif self.args.conditioning_mode in ["embed","simple_x_attn","x_attn"]:
            features = self.feature_dim2_linear(features)
            features = self.feature_dim1_linear(features.permute(0,1,3,2)).permute(0,1,3,2)
            features = self.feature_downnet(features, dummy_timesteps, encoder_hidden_states=None)["sample"]
            encoder_hidden_states = torch.nn.functional.avg_pool2d(features, kernel_size=(features.shape[2],features.shape[3]))
            encoder_hidden_states = encoder_hidden_states[:,None,:,0,0]
            x = self.image_unet(image, dummy_timesteps, encoder_hidden_states=encoder_hidden_states)
        else:
            raise NotImplementedError(self.args.conditioning_mode)
        return x["sample"]

args = get_default_args()
args.conditioning_mode = "simple_x_attn"
model = ToreNet(args)
model.to("cuda")

number_of_parameters(model)
features, image, ground_truth = next(iter(train_dataloader))
features, image, ground_truth = (features.to(args.device), 
                                    image.to(args.device), 
                                    ground_truth.to(args.device))
print("image shape:", image.shape)
out = model(image, features)
print("out shape:", out.shape)

model has 5425502 trainable parameters
image shape: torch.Size([4, 1, 64, 64])
out shape: torch.Size([4, 3, 64, 64])


In [None]:
def train_loop(model, optimizer, train_dataloader, val_dataloader, args):
    model.train()
    pbar = tqdm.tqdm(range(args.num_epochs)*len(train_dataloader))
    pbar.set_description("training")
    for epoch in range(args.num_epochs):
        for i, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            features, image, ground_truth = batch
            features, image, ground_truth = (features.to(args.device), 
                                    image.to(args.device), 
                                    ground_truth.to(args.device))
            pred = model(image, features)
            loss = torch.nn.functional.cross_entropy(pred, ground_truth)
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                print(f"epoch {epoch}, iter {i}, loss {loss.item()}")
        with torch.no_grad():
            model.eval()
            for i, batch in enumerate(val_dataloader):
                features, image, ground_truth = batch
                features, image, ground_truth = (features.to(args.device), 
                                                 image.to(args.device), 
                                                 ground_truth.to(args.device))
                pred = model(image, features)
                loss = torch.nn.functional.cross_entropy(pred, ground_truth)
                val_loss = loss.item()
                if i % 10 == 0:
                    print(f"epoch {epoch}, iter {i}, val loss {val_loss}")
            model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
train_loop(model, optimizer, train_dataloader, val_dataloader, args)