# Imports & load dataset

In [65]:
%load_ext autoreload
%autoreload 2

import os
import zipfile

import torch
from torch import nn
from torch.nn.utils import prune

from airfoil_diffusion.airfoil_datasets import *
from airfoil_diffusion.networks import *
from airfoil_diffusion.trainer import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [66]:
DATASET_DIR = "./datasets/1_parameter/"
EXTRACTED_DIR = DATASET_DIR + "data/"

# Unpack the zipped simulation cases on first run only. Archives are extracted
# into a staging folder that is renamed to EXTRACTED_DIR only after every
# archive succeeded, so an interrupted run cannot leave a half-filled
# EXTRACTED_DIR that this existence check would later trust as complete.
if not os.path.exists(EXTRACTED_DIR):
    staging_dir = DATASET_DIR + "data_extracting/"
    archive_names = sorted(name for name in os.listdir(DATASET_DIR) if name.endswith(".zip"))
    if not archive_names:
        raise FileNotFoundError(f"No .zip archives found in {DATASET_DIR}")
    for archive_name in tqdm(archive_names):
        # `with` guarantees the archive is closed even if extraction fails.
        with zipfile.ZipFile(DATASET_DIR + archive_name, "r") as archive:
            archive.extractall(staging_dir)
    os.rename(os.path.normpath(staging_dir), os.path.normpath(EXTRACTED_DIR))

train_dataset = AirfoilDataset(
    FileDataFiles(DATASET_DIR + "train_cases.txt", base_path=EXTRACTED_DIR),
    data_size=32,
)

Loading data:  23%|██▎       | 29/125 [00:00<00:00, 282.93it/s]

Loading data: 100%|██████████| 125/125 [00:00<00:00, 293.81it/s]


# Train model

In [67]:
# Hyper-parameters forwarded verbatim to AifNet(**network_configs).
network_configs = dict(
    attention_layers=[2, 3],
    condition_layers=[-2],
    depth_each_layer=2,
    dim_basic=16,
    dim_condition=3,
    dim_encoded_time=8,
    # AifNet reports dim_in: 6 for this config (see the show_current_configs
    # output) — presumably because use_input_condition=True; confirm in AifNet.
    dim_in=3,
    # Commented-out alternative kept for reference: dim_multipliers=[1, 2, 4, 4]
    dim_multipliers=[2, 2, 2, 2],
    dim_out=3,
    heads_attention=4,
    linear_attention=False,
    skip_connection_scale=0.707,
    use_input_condition=True,
)

In [101]:
# Instantiate the network and print its resolved configuration.
# NOTE(review): the printout shows dim_in: 6 although network_configs sets 3 —
# presumably doubled because use_input_condition=True; confirm in AifNet.
network = AifNet(**network_configs)
network.show_current_configs()

attention_layers: [2, 3]
condition_layers: [-2]
depth_each_layer: 2
dim_basic: 16
dim_condition: 3
dim_encoded_time: 8
dim_in: 6
dim_multipliers: [2, 2, 2, 2]
dim_out: 3
heads_attention: 4
linear_attention: False
skip_connection_scale: 0.707
use_input_condition: True
condition_dim: 0


In [69]:
diffusion_trainer = DiffusionTrainer()

# Training hyper-parameters, passed as keyword arguments to train_from_scratch.
train_configs = {
    "name": "training",
    # NOTE(review): the trailing "32/" presumably mirrors data_size=32 of the
    # dataset — keep the two in sync when changing resolution.
    "save_path": "./training/single_parameter/32/",
    # Fall back to CPU instead of crashing on machines without CUDA
    # (the long run below is only practical on a GPU, though).
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    # Full-batch training: a single batch holds the whole training set.
    "batch_size_train": len(train_dataset),
    "shuffle_train": True,
    "num_workers_train": 0,
    # NOTE(review): 0 presumably disables validation — confirm in DiffusionTrainer.
    "validation_epoch_frequency": 0,
    "optimizer": "AdamW",
    "lr_scheduler": "step",
    "warmup_epoch": 0,
    "record_iteration_loss": False,
    "epochs": 125000,
    "save_epoch": 5000,
    "lr": 0.0001,
    "final_lr": 0.00001
}

In [None]:
# Long training run (epochs=125000 in train_configs).
# NOTE(review): this cell has no execution count in the saved notebook, and
# later cells rebind `network` to a fresh AifNet — on a fresh kernel run it
# directly after the cell that builds `network` above.
diffusion_trainer.train_from_scratch(network, train_dataset, **train_configs)

In [70]:
def evaluate_sparcity(model: nn.Module):
    """Count the parameters of weight-bearing layers and the zeros in their weights.

    Every module in ``model`` (including ``model`` itself) that exposes a tensor
    ``weight`` contributes all of its *own* parameters (weight — or
    ``weight_orig`` once pruned with ``torch.nn.utils.prune`` — plus bias, etc.)
    to the total. Only entries of ``weight`` that are exactly zero are counted
    as zeros. Because biases appear in the denominator only, ``zeros / total``
    slightly under-estimates the true weight sparsity.

    Args:
        model: network to inspect; it is not modified.

    Returns:
        ``(total_params, total_zeros)`` as plain Python ints; ``(0, 0)`` for a
        model without weighted layers.
    """
    total_params = 0
    total_zeros = 0

    for module in model.modules():
        weight = getattr(module, "weight", None)
        # Skip weightless modules and affine-less norm layers whose weight is None.
        if not isinstance(weight, torch.Tensor):
            continue
        # recurse=False: count only this module's own parameters. Nested
        # weight-bearing children are visited (and counted) on their own, so
        # this avoids counting them twice.
        total_params += sum(p.numel() for p in module.parameters(recurse=False))
        total_zeros += int((weight == 0).sum())

    return total_params, total_zeros


In [71]:
def _prune(network: nn.Module, prune_type:str, pruning_percentage:float):
    norm_n = {'L1': 1, 'L2': 2}[prune_type]

    for module in network.modules():
        if type(module) is nn.Conv2d and module.out_channels > 3:
            prune.ln_structured(module, 'weight', amount=pruning_percentage, dim=0, n=norm_n)


def prune_remove(network):
    for module in network.modules():
        if type(module) is nn.Conv2d and module.out_channels > 3:
            prune.remove(module, 'weight')


In [102]:
# Rebind `network` to a freshly constructed AifNet as the pruning baseline and
# confirm that no weight is exactly zero yet (saved output: (426194, 0)).
# NOTE(review): no checkpoint is loaded, so the pruning below acts on freshly
# initialised weights, not on the trained model — load saved weights here if
# pruning the trained network is intended.
network = AifNet(**network_configs)
evaluate_sparcity(network)

(426194, 0)

In [None]:
# One-shot pruning: zero 50% of the filters (ranked by L1 norm) of every conv
# selected by _prune.
# NOTE(review): this cell was not executed in the saved run (In [None]). Running
# it AND the iterative loop below stacks both passes on the same `network`.
_prune(network, "L1", 0.5)

In [77]:
# Percentage of the counted parameters that are exactly zero after pruning.
total_w, n_zeros = evaluate_sparcity(network)
"Sparse: {:.2f}%".format(100 * n_zeros / total_w)

'Sparse: 44.73%'

In [87]:
# Fraction of filters expected to remain after the 8-round loop below, assuming
# each round prunes 8% of the filters still unpruned: 0.92**8 ≈ 0.513, i.e.
# ~48.7% pruned — comparable to the 50% one-shot setting.
0.92**8

0.5132188731375618

In [103]:
# Iterative alternative to one-shot pruning: 8 rounds of 8% L1 filter pruning.
for _ in range(8):
    _prune(network, prune_type="L1", pruning_percentage=(1 - 0.92))

In [104]:
# Percentage of the counted parameters that are exactly zero after pruning.
total_w, n_zeros = evaluate_sparcity(network)
"Sparse: {:.2f}%".format(100 * n_zeros / total_w)

'Sparse: 44.24%'