<a href="https://colab.research.google.com/github/TOTTO27149/BerryIMU/blob/master/ZiCo_Calculation%2C_Poor_model2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Setting things up and helper functions**

In [None]:
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
import numpy as np
from PIL import Image

In [None]:
# Mount Google Drive at /content/drive so the dataset and weight paths
# under /content/drive/MyDrive/... used by later cells are reachable.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
class ResidualConvBlock(nn.Module):
    """Two 3x3 Conv-BatchNorm-GELU layers with an optional residual connection.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        is_res: if True, add the block input to the output of the second conv
            and divide the sum by 1.414. When ``in_channels != out_channels``,
            the input is first projected by a learned 1x1 convolution
            (``self.shortcut``).

    Note: the 1x1 shortcut is a registered submodule, so its weights appear in
    ``state_dict`` as ``shortcut.*``. State dicts saved by a version that built
    this conv inside ``forward()`` do not contain those keys; load such
    checkpoints with ``strict=False``.
    """

    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()

        self.same_channels = in_channels == out_channels
        self.is_res = is_res

        # First convolutional layer (k=3, s=1, p=1 keeps the spatial size)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        # Second convolutional layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        # 1x1 projection that matches channel counts on the residual path.
        # Created once here so that it is registered as a submodule: it is
        # trained by the optimizer, moved by .to(device) and saved in
        # state_dict. Building it inside forward() would create a new,
        # randomly initialised and untrainable conv on every call.
        if is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )
        else:
            self.shortcut = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv layers, plus the scaled residual when ``is_res``."""
        x1 = self.conv1(x)
        x2 = self.conv2(x1)

        # Without a residual connection, return the second conv's output.
        if not self.is_res:
            return x2

        residual = x if self.same_channels else self.shortcut(x)
        out = residual + x2
        # 1.414 ~= sqrt(2); presumably keeps the sum's scale comparable to
        # each of its two terms.
        return out / 1.414

    # Method to get the number of output channels for this block
    def get_out_channels(self):
        return self.conv2[0].out_channels

    # Method to set the number of output channels for this block.
    # NOTE: this only rewrites the Conv2d ``in_channels``/``out_channels``
    # attributes; the weight tensors, BatchNorm layers and shortcut are not
    # resized.
    def set_out_channels(self, out_channels):
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels

class UnetUp(nn.Module):
    """Up-sampling stage: a 2x transposed conv followed by two ResidualConvBlocks.

    NOTE (CHANGE 1): the U-Net skip connection is deliberately removed in this
    notebook to weaken the model. The original forward was
    ``self.model(torch.cat((x, skip), 1))``; here ``x`` is concatenated with
    itself instead, so ``x`` must carry ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layer order fixes both the state_dict keys (model.0/1/2) and the
        # order in which the layers are randomly initialised.
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),  # doubles H and W
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        )

    def forward(self, x):
        # CHANGE 1: (x, x) stands in for the removed (x, skip) pair.
        doubled = torch.cat((x, x), 1)
        return self.model(doubled)


class UnetDown(nn.Module):
    """Down-sampling stage: two ResidualConvBlocks, then 2x2 max-pooling.

    The pooling halves H and W (rounding down for odd sizes).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            ResidualConvBlock(in_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        pooled = self.model(x)
        return pooled

class EmbedFC(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) mapping ``input_dim`` to ``emb_dim``.

    The input is reshaped to ``(-1, input_dim)`` first. For example, a
    per-sample time-step tensor of shape (batch,) works with ``input_dim=1``.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x):
        rows = x.view(-1, self.input_dim)
        return self.model(rows)

class CustomDataset(Dataset):
    """Sprite images and their context labels, loaded from two ``.npy`` files.

    Args:
        sfilename: path to the sprite array; the first axis indexes samples.
        lfilename: path to the label array, aligned with the sprites.
        transform: callable applied to each sprite, or None (or any falsy
            value) to return the raw array element.
        null_context: if True, every label is replaced by ``tensor(0)``.
    """

    def __init__(self, sfilename, lfilename, transform, null_context=False):
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    def __len__(self):
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return ``(image, label)``, where ``label`` is an int64 tensor."""
        image = self.sprites[idx]
        # Only the image depends on the transform. Building the label inside
        # the transform branch left both names unbound (UnboundLocalError)
        # whenever no transform was given.
        if self.transform:
            image = self.transform(image)
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    def getshapes(self):
        """Return ``(sprites_shape, labels_shape)`` as loaded from disk."""
        return self.sprites_shape, self.slabel_shape

# Per-sprite preprocessing. ToTensor turns an H x W x C image array (uint8
# values assumed) into a C x H x W float tensor in [0.0, 1.0]. Normalize with
# mean=0.5 and std=0.5 then maps that to [-1, 1] via (v - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

**Model: U-net**

In [None]:
class ContextUnet(nn.Module):
    """Context-conditioned U-Net noise predictor, with skip connections removed.

    NOTE (CHANGE 2): this is the deliberately impaired variant used in this
    notebook. The decoder no longer receives the encoder features
    ``down2``/``down1``/``x``, and each up stage concatenates its input with
    itself instead. The original decoder was::

        up2 = self.up1(cemb1*up1 + temb1, down2)
        up3 = self.up2(cemb2*up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))

    Args:
        in_channels: image channels (3 for the RGB sprites used here).
        n_feat: base number of feature maps.
        n_cfeat: length of the context vector.
        height: image height, assumed equal to the width. It must be divisible
            by 4, because the bottleneck is re-expanded by a factor of
            ``height // 4``.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height  # assume h == w. must be divisible by 4, so 28,24,20,16...

        # Encoder
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)        # -> (n_feat, h/2, w/2)
        self.down2 = UnetDown(n_feat, 2 * n_feat)    # -> (2*n_feat, h/4, w/4)

        # Bottleneck: 4x4 average pooling followed by GELU
        self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())

        # Embed the time step and context vector with two-layer MLPs (EmbedFC)
        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)

        # Decoder
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4),  # up-sample by h//4
            nn.GroupNorm(8, 2 * n_feat),  # normalize
            nn.ReLU(),
        )
        # Input channels are doubled because UnetUp concatenates its input
        # with itself before up-sampling.
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),  # reduce feature maps; input is cat(up3, up3)
            nn.GroupNorm(8, n_feat),  # normalize
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),  # map back to the image's channel count
        )

    def forward(self, x, t, c=None):
        """Predict the noise contained in ``x``.

        Args:
            x: (batch, in_channels, h, w) noisy input image.
            t: one time-step value per sample (this notebook passes
                ``t / timesteps``). It can have any shape that ``view(-1, 1)``
                reshapes to (batch, 1), e.g. (batch,).
            c: (batch, n_cfeat) context vector, or None for an all-zero context.

        Returns:
            Tensor with ``in_channels`` channels, used as the predicted noise.
        """
        x = self.init_conv(x)       # (batch, n_feat, h, w)
        down1 = self.down1(x)       # (batch, n_feat, h/2, w/2)
        down2 = self.down2(down1)   # (batch, 2*n_feat, h/4, w/4)

        hiddenvec = self.to_vec(down2)

        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # The context embeddings multiply and the time embeddings shift the
        # decoder features, channel by channel.
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)     # (batch, 2*n_feat, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)

        # CHANGE 2: no encoder features are passed to up1/up2, and the final
        # concatenation pairs up3 with itself instead of with x.
        up2 = self.up1(cemb1*up1 + temb1)
        up3 = self.up2(cemb2*up2 + temb2)
        out = self.out(torch.cat((up3, up3), 1))

        return out


**Hyperparameters and DDPM noise schedule**

In [None]:
# ---------------- hyperparameters ----------------

# Diffusion schedule: betas rise linearly from beta1 to beta2 (see the noise
# schedule cell) over `timesteps` steps.
timesteps = 500
beta1 = 1e-4   # first beta of the linear DDPM schedule
beta2 = 0.02   # last beta of the linear DDPM schedule

# Network
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_feat = 64    # base number of hidden feature maps
n_cfeat = 5    # length of the context vector
height = 16    # sprites are 16x16
save_dir = '/content/drive/MyDrive/How-Diffusion-Models-Work-main/weights/'
saveModel_dir = '/content/drive/MyDrive/How-Diffusion-Models-Work-main/weights/Poor1TrainedByXinda/'

In [None]:
# DDPM noise schedule (parameters as defined in the DDPM paper). These tensors
# fix how much noise is mixed into an image at each time step t:
#   b_t  - betas, linear from beta1 to beta2 over timesteps + 1 entries
#   a_t  - alphas, 1 - b_t
#   ab_t - running product of a_t, computed as exp(cumsum(log a_t))
b_t = beta1 + (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device)
a_t = 1 - b_t
ab_t = a_t.log().cumsum(dim=0).exp()
ab_t[0] = 1  # step 0 leaves the image un-noised

**Training: initialisation**

In [None]:
# Training hyperparameters.
# NOTE(review): lrate and n_epoch are not used by the gradient-collection loop
# below, which builds its own Adam optimizer with lr=3e-4.
batch_size = 100
n_epoch = 32
lrate = 1e-3

# Sprite dataset with one-hot context labels. shuffle=False means every pass
# over the loader yields the batches in the same order.
dataset = CustomDataset(
    "/content/drive/MyDrive/How-Diffusion-Models-Work-main/sprites_1788_16x16.npy",
    "/content/drive/MyDrive/How-Diffusion-Models-Work-main/sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)
dataloader = DataLoader(
    dataset, batch_size=batch_size, shuffle=False, num_workers=1
)

# Randomly initialised model. NOTE: torch.manual_seed(42) is only called in a
# later cell, so this initialisation is not seeded.
nn_model = ContextUnet(
    in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height
).to(device)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


In [None]:
# Count every parameter element in the model, trainable or not.
total_params = sum(map(torch.numel, nn_model.parameters()))
print(total_params)

1480771


**Training: collect gradients and compute the ZiCo score**

In [None]:
# training without context code
from collections import defaultdict

In [None]:
# Number of batches N
N = 10
Sample_num = 0
torch.manual_seed(42)

<torch._C.Generator at 0x7cc29dbf8050>

In [None]:
# Number of independent ZiCo estimates. Each estimate backpropagates N
# mini-batches (fresh noise, time steps and context masks) through the same
# model, which is never updated, and summarises the resulting gradients.
Num_Score = 30

# One entry per estimate: a float score, or a message string when the
# estimate was skipped because some gradient had zero std across batches.
ZiCo_Score = []


def perturb_input(x, t, noise):
    """Noise clean images ``x`` to diffusion step ``t`` (DDPM forward process).

    Returns sqrt(ab_t[t]) * x + sqrt(1 - ab_t[t]) * noise, where ``ab_t`` is
    the cumulative alpha schedule built earlier in the notebook.
    """
    # The noise coefficient must be sqrt(1 - ab_t) to match the DDPM marginal
    # q(x_t | x_0) = N(sqrt(ab_t) x_0, (1 - ab_t) I). Scaling by (1 - ab_t)
    # under-scales the noise.
    return (
        ab_t.sqrt()[t, None, None, None] * x
        + (1 - ab_t[t, None, None, None]).sqrt() * noise
    )


def _abs_grads_per_layer(model):
    """Return {child_name: 1-D float64 tensor of |grad|} for ``model``'s children.

    Each child's parameter gradients are flattened and concatenated in
    ``parameters()`` order. Children without parameters are omitted, and
    parameters whose ``grad`` is None are skipped.
    """
    per_layer = {}
    for layer_name, submodule in model.named_children():
        flat = [
            param.grad.detach().abs().reshape(-1).double()
            for param in submodule.parameters()
            if param.grad is not None
        ]
        if flat:
            per_layer[layer_name] = torch.cat(flat)
    return per_layer


def _zico_from_batches(layer_grads):
    """Compute the ZiCo score from per-batch absolute gradients.

    Args:
        layer_grads: {layer_name: list of 1-D tensors}, with one tensor per
            mini-batch, all of equal length within a layer.

    Returns:
        (score, stats). ``score`` is
        sum over layers of log(sum_i mean_b|g_i| / std_b|g_i|), or a message
        string naming the first parameter whose std across batches is zero.
        ``stats`` maps layer_name -> {'expectation': tensor, 'std_dev': tensor}
        for the layers processed.
    """
    # NOTE(review): confirm that this per-parameter mean(|g|)/std(|g|) over
    # top-level children matches the ZiCo definition you intend to reproduce.
    zico = 0
    stats = {}
    for layer_name, grads in layer_grads.items():
        stacked = torch.stack(grads)               # (num_batches, num_params)
        expectation = stacked.mean(dim=0)
        # Population std (divide by num_batches), like np.std's default ddof=0.
        std_dev = stacked.std(dim=0, correction=0)
        stats[layer_name] = {'expectation': expectation, 'std_dev': std_dev}

        zero_idx = torch.nonzero(std_dev == 0).flatten()
        if zero_idx.numel() > 0:
            param_idx = int(zero_idx[0])
            return (f"Standard deviation for layer {layer_name}, parameter {param_idx} is zero.", stats)

        zico += np.log(float((expectation / std_dev).sum()))
    return zico, stats


# The optimizer is never stepped; it is only used to clear gradients, so the
# model parameters stay fixed across all estimates.
optim = torch.optim.Adam(nn_model.parameters(), lr=3e-4)
nn_model.train()

for idx in range(Num_Score):
    # layer_name -> list holding one 1-D |grad| tensor per mini-batch
    layer_grads = defaultdict(list)

    pbar = tqdm(dataloader, mininterval=2)
    for batch_idx, (x, c) in enumerate(pbar):   # x: images  c: context, one-hot encoded vectors
        if batch_idx >= N:  # Stop after N batches
            break

        optim.zero_grad()
        x = x.to(device)
        c = c.to(x)

        # Randomly mask out the context (each sample keeps it with prob. 0.9)
        # so the model also learns what a sprite looks like without context.
        context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.9).to(device)
        c = c * context_mask.unsqueeze(-1)

        # Perturb the images to random time steps in [1, timesteps]
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device)
        x_pert = perturb_input(x, t, noise)

        # Backpropagate the noise-prediction MSE loss
        pred_noise = nn_model(x_pert, t / timesteps, c=c)
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()

        for layer_name, abs_grads in _abs_grads_per_layer(nn_model).items():
            layer_grads[layer_name].append(abs_grads)
    pbar.close()

    score, param_stats = _zico_from_batches(layer_grads)
    ZiCo_Score.append(score)
    if isinstance(score, str):  # zero std somewhere: skip this estimate
        continue

    Sample_num += 1
    print(f"Iteration {Sample_num} has finished.")

  1%|          | 10/894 [00:09<13:39,  1.08it/s] 


Iteration 1 has finished.


  1%|          | 10/894 [00:01<02:04,  7.08it/s]


Iteration 2 has finished.


  1%|          | 10/894 [00:01<02:06,  6.99it/s]


Iteration 3 has finished.


  1%|          | 10/894 [00:01<02:01,  7.29it/s]


Iteration 4 has finished.


  1%|          | 10/894 [00:02<03:03,  4.81it/s]


Iteration 5 has finished.


  1%|          | 10/894 [00:01<02:09,  6.82it/s]


Iteration 6 has finished.


  1%|          | 10/894 [00:01<02:05,  7.05it/s]


Iteration 7 has finished.


  1%|          | 10/894 [00:01<02:00,  7.32it/s]


Iteration 8 has finished.


  1%|          | 10/894 [00:02<03:03,  4.82it/s]


Iteration 9 has finished.


  1%|          | 10/894 [00:01<02:11,  6.71it/s]


Iteration 10 has finished.


  1%|          | 10/894 [00:01<02:02,  7.21it/s]


Iteration 11 has finished.


  1%|          | 10/894 [00:01<02:02,  7.20it/s]


Iteration 12 has finished.


  1%|          | 10/894 [00:01<02:08,  6.87it/s]


Iteration 13 has finished.


  1%|          | 10/894 [00:01<02:54,  5.08it/s]


Iteration 14 has finished.


  1%|          | 10/894 [00:01<02:02,  7.23it/s]


Iteration 15 has finished.


  1%|          | 10/894 [00:01<02:02,  7.23it/s]


Iteration 16 has finished.


  1%|          | 10/894 [00:01<02:01,  7.30it/s]


Iteration 17 has finished.


  1%|          | 10/894 [00:01<02:02,  7.20it/s]


Iteration 18 has finished.


  1%|          | 10/894 [00:01<02:55,  5.03it/s]


Iteration 19 has finished.


  1%|          | 10/894 [00:01<02:03,  7.15it/s]


Iteration 20 has finished.


  1%|          | 10/894 [00:01<02:01,  7.26it/s]


Iteration 21 has finished.


  1%|          | 10/894 [00:01<02:10,  6.75it/s]


Iteration 22 has finished.


  1%|          | 10/894 [00:01<02:45,  5.35it/s]


Iteration 23 has finished.


  1%|          | 10/894 [00:01<02:17,  6.42it/s]


Iteration 24 has finished.


  1%|          | 10/894 [00:01<02:00,  7.32it/s]


Iteration 25 has finished.


  1%|          | 10/894 [00:01<02:02,  7.21it/s]


Iteration 26 has finished.


  1%|          | 10/894 [00:01<02:07,  6.91it/s]


Iteration 27 has finished.


  1%|          | 10/894 [00:01<02:04,  7.09it/s]


Iteration 28 has finished.


  1%|          | 10/894 [00:02<02:59,  4.92it/s]


Iteration 29 has finished.


  1%|          | 10/894 [00:01<02:02,  7.21it/s]


Iteration 30 has finished.


In [None]:
# Display the collected scores: one float per completed estimate, or a message
# string for each estimate that was skipped because a gradient had zero std.
ZiCo_Score

[126.6309717794907,
 126.59616850635165,
 126.65744594667784,
 126.75171658740372,
 126.76662735653784,
 126.59396248492588,
 126.67090330474868,
 126.5159375882354,
 126.71475676498801,
 126.52840462546531,
 126.6195510285854,
 126.7191755851183,
 126.64170349953936,
 126.56409610434513,
 126.59614821570705,
 126.62585321538967,
 126.78679442382646,
 126.70588113369301,
 126.62811226822836,
 126.6184700657933,
 126.65342180420137,
 126.74812635916408,
 126.59162682568703,
 126.76746834491836,
 126.68480256769463,
 126.64805655767945,
 126.73639584790186,
 126.57909452074313,
 126.66200466903139,
 126.62607122341197]