In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import os
from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D, Trainer1D, Dataset1D

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset layout: <proj>/{input,target}/{fixed,temporal}
proj_dir = Path("..") / "Data/r77-mini-data-fortnight"
data_dir = proj_dir
print(proj_dir / "input")

input_dir = data_dir / "input"
target_dir = data_dir / "target"

fixed_input_dir, temporal_input_dir = input_dir / "fixed", input_dir / "temporal"
fixed_target_dir, temporal_target_dir = target_dir / "fixed", target_dir / "temporal"

../Data/r77-mini-data-fortnight/input


In [3]:
# NOTE(review): Path.iterdir() yields entries in arbitrary, filesystem-dependent
# order. `indices` is used to index positionally into these listings, and later
# cells pair files from different directories by position, so results rely on
# that order being stable and consistent — confirm (e.g. compare file names)
# before running on another machine.
index_dir = Path("..") / "Index"
files = [*temporal_target_dir.iterdir()]
files_index = [*index_dir.iterdir()]  # only the first entry is used
indices = np.squeeze(np.load(files_index[0]))

# First hour of data (first entry of `indices`)
index = indices[0]
tt = np.squeeze(np.load(files[index]))
tt.shape

(144, 70, 100, 3)

In [4]:
files_ti = [*temporal_input_dir.iterdir()]  # positionally paired with `files` — TODO confirm orders match

In [5]:
# Differenced training data: for each selected time step, subtract the temporal
# input profile of each row along axis 0 (144 per file, presumably grid boxes)
# from every one of the 100 target profiles that share that row.
N_TIMESTEPS = 360  # number of entries of `indices` to use
N_LEVELS = 64      # keep the first 64 entries of the level axis (model seq_length)

diff_chunks = []
for j in range(N_TIMESTEPS):
    index = indices[j]

    # target file (144, 70, 100, 3) -> (144, 100, 3, 70) -> (144, 100, 3, 64)
    tt = np.load(files[index]).squeeze()
    tt = np.transpose(tt, (0, 2, 3, 1))[:, :, :, :N_LEVELS]

    # input file -> (144, 3, 64)
    ti = np.load(files_ti[index]).squeeze()
    ti = np.transpose(ti, (0, 2, 1))[:, :, :N_LEVELS]

    # Broadcast ti over the profile axis instead of materialising 100 copies
    # with np.repeat; the element-wise result is identical.
    diff = tt - ti[:, None, :, :]                        # (144, 100, 3, 64)
    diff_chunks.append(diff.reshape(-1, 3, N_LEVELS))    # (14400, 3, 64)

data = np.concatenate(diff_chunks)
del diff_chunks  # release the per-file copies now that they are concatenated
data.shape

(5184000, 3, 64)

In [6]:
def normalise(vector):
    """Min-max scale ``vector`` onto [0, 1] using its global min and max.

    The min and max are taken over *all* elements at once (no axis), so for a
    (n, 3, 64) array the three channels share one scale.

    Parameters
    ----------
    vector : array_like
        Values to scale. Must be non-empty (``np.min`` raises otherwise).

    Returns
    -------
    tuple
        ``(normalised_vector, min_val, max_val)``; keep ``min_val``/``max_val``
        to invert the scaling: ``x = y * (max_val - min_val) + min_val``.
        A constant input (``max_val == min_val``) maps to all zeros instead of
        the NaNs a 0/0 division would produce.
    """
    min_val = np.min(vector)
    max_val = np.max(vector)
    shifted = np.asarray(vector) - min_val
    span = max_val - min_val
    if span == 0:
        # Every element equals the minimum: return zeros in the floating dtype
        # the division would have produced (float32 stays float32, ints -> float64).
        return np.zeros_like(shifted, dtype=np.result_type(shifted, 1.0)), min_val, max_val
    return shifted / span, min_val, max_val

In [7]:
files_fixed = [*fixed_input_dir.iterdir()]
files_ft = [*fixed_target_dir.iterdir()]

# Per row of the fixed files (one per grid box), build three fixed variables:
# the fixed-input columns plus the standard deviation of column 1 of the fixed
# target (treated as orography, hence `stdev_orog` downstream), taken across each
# row's values. Each row is then repeated 100 times — once per profile — so that
# `cond` lines up row-for-row with `data`.
cond_chunks = []
for j in range(360):
    index = indices[j]
    fi = np.load(files_fixed[index]).squeeze()
    ft = np.load(files_ft[index]).squeeze()

    orog_std = np.array([np.std(row) for row in ft[:, :, 1]]).reshape(-1, 1)
    per_grid = np.hstack((fi, orog_std))
    cond_chunks.append(np.repeat(per_grid, 100, axis=0))

cond = np.concatenate(cond_chunks).reshape(-1, 3)
del cond_chunks  # only the concatenated array is needed from here on
cond.shape

(5184000, 3)

In [8]:
# Name the three conditioning columns (each is a column view of `cond`).
lsf, orog, stdev_orog = cond.T

In [9]:
# All raw (non-differenced) target profiles, in the same row order as `data`;
# used below to flag which profiles contain a temperature inversion.
def _load_target_profiles(file_idx):
    """Load one temporal-target file as C-ordered (14400, 3, 64) profiles."""
    raw = np.load(files[file_idx]).squeeze()
    raw = np.transpose(raw, (0, 2, 3, 1))[:, :, :, :64]
    return raw.reshape((14400, 3, 64), order='C')


tt_all = np.stack([_load_target_profiles(indices[i]) for i in range(360)])

all_data = tt_all.reshape(-1, 3, 64)
all_data.shape

(5184000, 3, 64)

In [10]:
TROPOSPHERE_LEVELS = 40  # only the first 40 levels are searched for an inversion


def has_inversion(temperature_profile, troposphere_height, axis=-1):
    """Return whether a temperature profile contains an inversion.

    A profile is flagged when the finite-difference gradient (``np.gradient``) of
    its first ``troposphere_height`` values along ``axis`` is strictly positive
    anywhere, i.e. temperature increases with increasing level index.

    Parameters
    ----------
    temperature_profile : array_like
        A single 1-D profile, or an N-D stack of profiles with levels along ``axis``.
    troposphere_height : int
        Number of leading levels to inspect (``np.gradient`` needs at least 2).
    axis : int, optional
        Axis holding the levels (default: last axis).

    Returns
    -------
    numpy.bool_ or numpy.ndarray of bool
        A scalar for 1-D input, otherwise one flag per profile.
    """
    profiles = np.moveaxis(np.asarray(temperature_profile), axis, -1)[..., :troposphere_height]
    gradient = np.gradient(profiles, axis=-1)
    return np.any(gradient > 0.0, axis=-1)


temp = all_data[:, 2, :]

# One vectorised call over all profiles: same flags as calling has_inversion row
# by row, without a Python loop over millions of rows.
inversion = has_inversion(temp, TROPOSPHERE_LEVELS)
np.count_nonzero(~inversion)

1220648

In [11]:
# Coarse two-way split of the differenced profiles by inversion flag alone.
no_inversion = ~inversion
indices1 = np.flatnonzero(no_inversion)
indices2 = np.flatnonzero(inversion)
data1, data2 = data[indices1], data[indices2]

print(f'no inversion training data {data1.shape}')
print(f'has inversion training data {data2.shape}')


no inversion training data (1220648, 3, 64)
has inversion training data (3963352, 3, 64)


In [12]:
# Ten classes = {no inversion, inversion} x five surface types (using `lsf` and `orog`):
#   1/6  : lsf <= 0.2
#   2/7  : lsf == 1,        orog <  0.06      3/8  : lsf == 1,        orog >= 0.06
#   4/9  : 0.2 < lsf < 1,   orog <  0.03      5/10 : 0.2 < lsf < 1,   orog >= 0.03
# Upper classes use >= so a value lying exactly on a threshold is assigned to a
# class instead of being dropped by both `<` and `>`.
SEA_LSF_MAX = 0.2
LAND_OROG_SPLIT = 0.06
COAST_OROG_SPLIT = 0.03

is_sea = lsf <= SEA_LSF_MAX
is_land = lsf == 1
is_coast = (lsf > SEA_LSF_MAX) & (lsf < 1)

class_masks = []
for inv_flag in (False, True):
    same_inv = inversion == inv_flag
    class_masks += [
        same_inv & is_sea,
        same_inv & is_land & (orog < LAND_OROG_SPLIT),
        same_inv & is_land & (orog >= LAND_OROG_SPLIT),
        same_inv & is_coast & (orog < COAST_OROG_SPLIT),
        same_inv & is_coast & (orog >= COAST_OROG_SPLIT),
    ]

class_indices = [np.flatnonzero(mask) for mask in class_masks]
del class_masks  # boolean masks are no longer needed once converted to indices
(indices1, indices2, indices3, indices4, indices5,
 indices6, indices7, indices8, indices9, indices10) = class_indices
(trainingdata1, trainingdata2, trainingdata3, trainingdata4, trainingdata5,
 trainingdata6, trainingdata7, trainingdata8, trainingdata9,
 trainingdata10) = [data[ix] for ix in class_indices]

In [13]:
# Size of each class and its share (%) of all profiles. `inversion` holds one
# flag per profile, so its length is the total profile count.
n_profiles = len(inversion)
for k, class_data in enumerate(
    [trainingdata1, trainingdata2, trainingdata3, trainingdata4, trainingdata5,
     trainingdata6, trainingdata7, trainingdata8, trainingdata9, trainingdata10],
    start=1,
):
    print(f'class{k} training data {class_data.shape}, {100 * class_data.shape[0] / n_profiles:.1f}')

class1 training data (645213, 3, 64), 12.4
class2 training data (114201, 3, 64),2.1999999999999997
class3 training data (60029, 3, 64), 1.2
class4 training data (266266, 3, 64), 5.1
class5 training data (134939, 3, 64), 2.6
class6 training data (2018787, 3, 64), 38.9
class7 training data (425799, 3, 64), 8.200000000000001
class8 training data (191971, 3, 64), 3.6999999999999997
class9 training data (885734, 3, 64), 17.1
class10 training data (441061, 3, 64), 8.5


In [16]:
# The ten per-class arrays, in class order 1..10.
data_list = [
    trainingdata1, trainingdata2, trainingdata3, trainingdata4, trainingdata5,
    trainingdata6, trainingdata7, trainingdata8, trainingdata9, trainingdata10,
]
# Sanity check: the class sizes should add up to the total number of profiles.
sum(len(class_data) for class_data in data_list) == data.shape[0]

True

In [18]:

# Train one 1-D diffusion model per class and draw 10,000 samples from each.
# Samples come out in the same normalised space as the training data, so the
# min/max used by `normalise` are saved to norm_stats{idx}.npy to allow mapping
# samples back: x = s * (max - min) + min.
for idx, trainingdata in enumerate(data_list, start=1):
    torch.manual_seed(idx)  # reproducible initialisation, shuffling and sampling per class

    model = Unet1D(
        dim=64,
        dim_mults=(1, 2, 4, 8),
        channels=3,
    )

    diffusion = GaussianDiffusion1D(
        model,
        seq_length=64,
        timesteps=100,
        objective='pred_v',
    )

    # Dedicated name: assigning to `data` here would clobber the full differenced
    # dataset that the partition/check cells index into.
    normalised, min_val, max_val = normalise(trainingdata)
    np.save(f"norm_stats{idx}.npy", np.array([min_val, max_val]))
    training_seq = torch.from_numpy(normalised)

    trainer = Trainer1D(
        diffusion,
        dataset=training_seq,
        train_batch_size=10,
        train_lr=1e-4,
        train_num_steps=100000,               # total training steps
        gradient_accumulate_every=2,          # gradient accumulation steps
        ema_decay=0.995,                      # exponential moving average decay
        amp=True,                             # mixed precision
        results_folder=f"./results_class{idx}",  # separate checkpoints per class
    )
    trainer.train()

    sampled_seq = diffusion.sample(batch_size=10000)
    torch.save(sampled_seq, f"sampled_seq{idx}.pt")

dataloader_config = DataLoaderConfiguration(split_batches=True)
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 111.97it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.64it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 123.19it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.42it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.50it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.93it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.46it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.60it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.25it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.72it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.26it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.06it/s]
sampling loop time step:

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.36it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.37it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.91it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.65it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.57it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.94it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.53it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.28it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.33it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.79it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.60it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.23it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.28it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.40it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.62it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.42it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.32it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.61it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.25it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.06it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.78it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.83it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.93it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.25it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.30it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.05it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.61it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.87it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.86it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.72it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.25it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.70it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.65it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.20it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.84it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.16it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.60it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.90it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.64it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.19it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.30it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.66it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.56it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.53it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.70it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.09it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.04it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.35it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 137.13it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 138.71it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.21it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.39it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.59it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.81it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 129.12it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 129.41it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 129.89it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.21it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.89it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.93it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.37it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.56it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.19it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.73it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.00it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.40it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.39it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.49it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.35it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.49it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.54it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.91it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.56it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.41it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.41it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.11it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.03it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.29it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.79it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.06it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.19it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.73it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.15it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.63it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.51it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.11it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.77it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.18it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.04it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.99it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.32it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.85it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.36it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.40it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.32it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.93it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.28it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.90it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.00it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.58it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.62it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.85it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.39it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.98it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 148.04it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.56it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.33it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.82it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.53it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.28it/s]


In [19]:
# Gather the per-class samples written by the training loop into one array with
# a leading class axis (10 x n_samples x 3 x 64 when every class saved the same
# number of samples). map_location="cpu" lets files saved from a GPU load on a
# machine without one.
sampled_sequences_list = [
    torch.load(f"sampled_seq{i}.pt", map_location="cpu").numpy()
    for i in range(1, 11)
]
sampled_sequences = np.array(sampled_sequences_list)

In [20]:
# Persist all class samples of this run (classes split on `orog`) in one .npy file.
folder = 'Samples_cond_diffs_A'
os.makedirs(folder, exist_ok=True)  # create on first run, reuse afterwards

file_path = os.path.join(folder, 'sample.npy')
np.save(file_path, sampled_sequences)

In [11]:
# Ten classes = {no inversion, inversion} x five surface types, with land and
# coastal boxes split on `stdev_orog` (per-grid-box orography standard deviation):
#   1/6  : lsf <= 0.2
#   2/7  : lsf == 1,        stdev_orog <  0.025   3/8  : lsf == 1,       stdev_orog >= 0.025
#   4/9  : 0.2 < lsf < 1,   stdev_orog <  0.03    5/10 : 0.2 < lsf < 1,  stdev_orog >= 0.03
# Upper classes use >= so a value lying exactly on a threshold is assigned to a
# class instead of being dropped by both `<` and `>`.
SEA_LSF_MAX = 0.2
LAND_STD_SPLIT = 0.025
COAST_STD_SPLIT = 0.03

is_sea = lsf <= SEA_LSF_MAX
is_land = lsf == 1
is_coast = (lsf > SEA_LSF_MAX) & (lsf < 1)

class_masks = []
for inv_flag in (False, True):
    same_inv = inversion == inv_flag
    class_masks += [
        same_inv & is_sea,
        same_inv & is_land & (stdev_orog < LAND_STD_SPLIT),
        same_inv & is_land & (stdev_orog >= LAND_STD_SPLIT),
        same_inv & is_coast & (stdev_orog < COAST_STD_SPLIT),
        same_inv & is_coast & (stdev_orog >= COAST_STD_SPLIT),
    ]

class_indices = [np.flatnonzero(mask) for mask in class_masks]
del class_masks  # boolean masks are no longer needed once converted to indices
(indices1, indices2, indices3, indices4, indices5,
 indices6, indices7, indices8, indices9, indices10) = class_indices
(trainingdata1, trainingdata2, trainingdata3, trainingdata4, trainingdata5,
 trainingdata6, trainingdata7, trainingdata8, trainingdata9,
 trainingdata10) = [data[ix] for ix in class_indices]

In [12]:
# Size of each class and its share (%) of all profiles. `inversion` holds one
# flag per profile, so its length is the total profile count.
n_profiles = len(inversion)
for k, class_data in enumerate(
    [trainingdata1, trainingdata2, trainingdata3, trainingdata4, trainingdata5,
     trainingdata6, trainingdata7, trainingdata8, trainingdata9, trainingdata10],
    start=1,
):
    print(f'class{k} training data {class_data.shape}, {100 * class_data.shape[0] / n_profiles:.1f}')

class1 training data (645213, 3, 64), 12.4
class2 training data (96400, 3, 64),1.9
class3 training data (77830, 3, 64), 1.5
class4 training data (278033, 3, 64), 5.4
class5 training data (123172, 3, 64), 2.4
class6 training data (2018787, 3, 64), 38.9
class7 training data (360200, 3, 64), 6.9
class8 training data (257570, 3, 64), 5.0
class9 training data (928967, 3, 64), 17.9
class10 training data (397828, 3, 64), 7.7


In [13]:
# The ten per-class arrays, in class order 1..10.
data_list = [
    trainingdata1, trainingdata2, trainingdata3, trainingdata4, trainingdata5,
    trainingdata6, trainingdata7, trainingdata8, trainingdata9, trainingdata10,
]
# Sanity check: the class sizes should add up to the total number of profiles.
sum(len(class_data) for class_data in data_list) == data.shape[0]

True

In [14]:

# Train one 1-D diffusion model per class and draw 10,000 samples from each.
# Samples come out in the same normalised space as the training data, so the
# min/max used by `normalise` are saved to norm_stats{idx}.npy to allow mapping
# samples back: x = s * (max - min) + min.
for idx, trainingdata in enumerate(data_list, start=1):
    torch.manual_seed(idx)  # reproducible initialisation, shuffling and sampling per class

    model = Unet1D(
        dim=64,
        dim_mults=(1, 2, 4, 8),
        channels=3,
    )

    diffusion = GaussianDiffusion1D(
        model,
        seq_length=64,
        timesteps=100,
        objective='pred_v',
    )

    # Dedicated name: assigning to `data` here would clobber the full differenced
    # dataset that the partition/check cells index into.
    normalised, min_val, max_val = normalise(trainingdata)
    np.save(f"norm_stats{idx}.npy", np.array([min_val, max_val]))
    training_seq = torch.from_numpy(normalised)

    trainer = Trainer1D(
        diffusion,
        dataset=training_seq,
        train_batch_size=10,
        train_lr=1e-4,
        train_num_steps=10000,                # total training steps
        gradient_accumulate_every=2,          # gradient accumulation steps
        ema_decay=0.995,                      # exponential moving average decay
        amp=True,                             # mixed precision
        results_folder=f"./results_class{idx}",  # separate checkpoints per class
    )
    trainer.train()

    sampled_seq = diffusion.sample(batch_size=10000)
    torch.save(sampled_seq, f"sampled_seq{idx}.pt")

dataloader_config = DataLoaderConfiguration(split_batches=True)
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 112.79it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 132.23it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 124.57it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.63it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.53it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.15it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.15it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.83it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.71it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.61it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.49it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.39it/s]
sampling loop time step:

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.26it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.42it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.43it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.60it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.82it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.68it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.55it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.83it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.17it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.71it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.00it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.50it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.39it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.91it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.96it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.88it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.10it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.25it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.12it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.71it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.16it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.51it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.07it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.56it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.46it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 128.78it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 129.17it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 128.98it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.34it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.13it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.86it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 127.33it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 128.10it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 127.73it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.92it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.61it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.35it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.79it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.75it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.86it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.93it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.89it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.67it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.87it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.16it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.86it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 128.56it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 129.65it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 129.37it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.49it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.62it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.78it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.10it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.96it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.24it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 147.48it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.63it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.65it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.04it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.30it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.05it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.68it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.57it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.51it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.38it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.04it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.53it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.49it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 131.14it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.70it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.22it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.84it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 130.82it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.52it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.77it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.31it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.58it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.35it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 140.29it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.44it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.38it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.37it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.57it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.99it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.68it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.61it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.83it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 146.78it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.17it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.29it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.80it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 129.30it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 129.10it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 129.09it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.53it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.64it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.40it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 142.87it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.32it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.06it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.47it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.28it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.12it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.32it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 144.40it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 143.98it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 145.10it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 141.72it/s]
sampling loop time step: 100%|██████████| 100/100 [00:00<00:00, 139.45it/s]
sampling loop

training complete


sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]


In [15]:
# Gather the per-class samples written by the training loop into one array with
# a leading class axis (10 x n_samples x 3 x 64 when every class saved the same
# number of samples). map_location="cpu" lets files saved from a GPU load on a
# machine without one.
sampled_sequences_list = [
    torch.load(f"sampled_seq{i}.pt", map_location="cpu").numpy()
    for i in range(1, 11)
]
sampled_sequences = np.array(sampled_sequences_list)

In [16]:
# Persist all class samples of this run (classes split on `stdev_orog`) in one .npy file.
folder = 'Samples_cond_diffs_C'
os.makedirs(folder, exist_ok=True)  # create on first run, reuse afterwards

file_path = os.path.join(folder, 'sample.npy')
np.save(file_path, sampled_sequences)

: 