In [2]:
from dquartic.utils import data_loader
# from dquartic.model.unet1d import UNet1d
from dquartic.utils.data_loader import DIAMSDataset
import torch
import torch.quantization

import torch
from torch.utils.data import Dataset, DataLoader


In [61]:
import gc
gc.collect(


    
    
)

14076

In [56]:

ms1_file = "npy/ms1_data_int32.npy"
ms2_file = "npy/ms2_data_cat_int32.npy"


In [57]:
dataset = DIAMSDataset(ms1_file=ms1_file, ms2_file=ms2_file, normalize="minmax")
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
data_iter = iter(data_loader)
ms2_1, ms1_1, ms2_2,ms1_2 = next(data_iter)


Info: Loaded 520 MS2 slice samples and 520 MS1 slice samples from NPY files.


In [5]:
model_init={
              "dim": 4,
      "channels": 1,
      "dim_mults": [
        1,
        2,
        2,
        2,
        4,
        4,
        4
      ],
      "conditional": True,
      "init_cond_channels": 1,
      "attn_cond_channels": 1,
      "tfer_dim_mult": 620,
      "downsample_dim": 40000,
      "simple": True
}

In [None]:

import math
from collections import namedtuple
from functools import partial, wraps

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.quantization as tq

from einops import rearrange
from packaging import version

from rotary_embedding_torch import RotaryEmbedding


AttentionConfig = namedtuple(
    "AttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
)

def exists(x):
    return x is not None

def default(val, d):
    return val if exists(val) else (d() if callable(d) else d)

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv1d(dim, default(dim_out, dim), 3, padding=1),
    )

def Downsample(dim, dim_out=None):
    return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1))
    def forward(self, x):
        return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = RMSNorm(dim)
    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        self.dim = dim
        self.theta = theta
    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.theta) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Block(nn.Module):
    def __init__(self, dim, dim_out, dropout=0.0):
        super().__init__()
        self.proj = nn.Conv1d(dim, dim_out, 3, padding=1)
        self.norm = RMSNorm(dim_out)
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.act(x)
        return self.dropout(x)

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, dropout=0.0):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim)
            else None
        )
        self.block1 = Block(dim, dim_out, dropout=dropout)
        self.block2 = Block(dim_out, dim_out)
        self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)


class Attend(nn.Module):
    def __init__(self, dropout=0.0, flash=False, scale=None):
        super().__init__()
        self.dropout = dropout
        self.scale = scale
        self.attn_dropout = nn.Dropout(dropout)
        self.flash = flash
        # We require PyTorch 2.0 or higher for built-in SDPA kernel
        assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), \
            "Flash attention requires PyTorch 2.0 or higher."

        # CPU config
        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if torch.cuda.is_available() and flash:
            device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
            device_version = version.parse(f"{device_properties.major}.{device_properties.minor}")
            if device_version > version.parse("8.0"):
                print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
                self.cuda_config = AttentionConfig(True, False, False)
            else:
                print_once("Non-A100 GPU, using math or mem efficient attention if on cuda")
                self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        is_cuda = q.is_cuda
        if exists(self.scale):
            default_scale = q.shape[-1]
            q = q * (self.scale / default_scale)

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))
        config = self.cuda_config if is_cuda else self.cpu_config

        with nn.attention.sdpa_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.dropout if self.training else 0.0
            )
        return out

    def forward(self, q, k, v):
        if self.flash:
            return self.flash_attn(q, k, v)

        scale = default(self.scale, q.shape[-1] ** -0.5)
        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
        return out

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv1d(hidden_dim, dim, 1), RMSNorm(dim))

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, "b (h c) n -> b h c n", h=self.heads), qkv)

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c n -> b (h c) n", h=self.heads)
        return self.to_out(out)

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32, flash=False, use_xattn=False, cond_dim=1):
        super().__init__()
        self.use_xattn = use_xattn
        self.heads = heads
        hidden_dim = dim_head * heads
        self.rotary_emb = RotaryEmbedding(dim=dim_head // 2)
        self.attend = Attend(flash=flash)

        if self.use_xattn:
            self.to_qv = nn.Conv1d(dim, hidden_dim * 2, 1, bias=False)
            self.to_k = nn.Conv1d(cond_dim, hidden_dim, 1, bias=False)
        else:
            self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Conv1d(hidden_dim, dim, 1)

    def forward(self, x, cond=None):
        if self.use_xattn and exists(cond):
            qv = self.to_qv(x).chunk(2, dim=1)
            q, v = map(lambda t: rearrange(t, "b (h c) n -> b h n c", h=self.heads), qv)
            k = rearrange(self.to_k(cond), "b (h c) n -> b h n c", h=self.heads)
        else:
            qkv = self.to_qkv(x).chunk(3, dim=1)
            q, k, v = map(lambda t: rearrange(t, "b (h c) n -> b h n c", h=self.heads), qkv)

        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)
        out = self.attend(q, k, v)
        out = rearrange(out, "b h n d -> b (h d) n")
        return self.to_out(out)

class HybridSelfAndCrossAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32, flash=False, cond_dim=1):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.rotary_emb = RotaryEmbedding(dim=dim_head // 2)
        self.attend = Attend(flash=flash)

        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
        self.to_qv = nn.Conv1d(dim, hidden_dim * 2, 1, bias=False)
        self.to_k = nn.Conv1d(cond_dim, hidden_dim, 1, bias=False)

        self.to_mid = nn.Conv1d(hidden_dim, dim, 1)
        self.to_out = nn.Conv1d(hidden_dim, dim, 1)

    def forward(self, x, cond):
        # self-attn
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, "b (h c) n -> b h n c", h=self.heads), qkv)
        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)
        x = self.attend(q, k, v)
        x = rearrange(x, "b h n d -> b (h d) n")
        mid = self.to_mid(x)

        # cross-attn
        qv = self.to_qv(mid).chunk(2, dim=1)
        q, v = map(lambda t: rearrange(t, "b (h c) n -> b h n c", h=self.heads), qv)
        k = rearrange(self.to_k(cond), "b (h c) n -> b h n c", h=self.heads)
        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)
        out = self.attend(q, k, v)
        out = rearrange(out, "b h n d -> b (h d) n")
        return self.to_out(out)


class ConditionalScaleShift(nn.Module):
    def __init__(self, time_emb_dim, dim):
        super().__init__()
        self.to_scale_shift = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim * 2))
    def forward(self, x, t):
        scale, shift = self.to_scale_shift(t).chunk(2, dim=-1)
        return x * (scale + 1) + shift

class LayerNorm1d(nn.Module):
    def __init__(self, channels, *, bias=True, eps=1e-5):
        super().__init__()
        self.bias = bias
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, channels, 1))
        self.b = nn.Parameter(torch.zeros(1, channels, 1)) if bias else None

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        norm = (x - mean) * (var + self.eps).rsqrt() * self.g
        return norm + self.b if self.bias else norm

class FeedForward1d(nn.Module):
    def __init__(self, channels, ch_mult=2):
        super().__init__()
        self.net = nn.Sequential(
            LayerNorm1d(channels=channels),
            nn.Conv1d(channels, channels * ch_mult, 1),
            nn.GELU(),
            nn.Conv1d(channels * ch_mult, channels, 1),
        )
    def forward(self, x):
        return self.net(x)


class Transformer1d(nn.Module):
    def __init__(
        self, 
        dim, 
        depth=4, 
        heads=4, 
        dim_head=32, 
        mlp_dim=None, 
        use_xattn=False, 
        cond_dim=1
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for i in range(depth):
            # for simplicity, we alternate attention types
            if i < depth // 2 or not use_xattn:
                self.layers.append(
                    nn.ModuleList([
                        Attention(dim, heads=heads, dim_head=dim_head),
                        FeedForward1d(dim, default(mlp_dim, dim * 2)),
                    ])
                )
            else:
                self.layers.append(
                    nn.ModuleList([
                        HybridSelfAndCrossAttention(dim, heads=heads, dim_head=dim_head, cond_dim=cond_dim),
                        FeedForward1d(dim, default(mlp_dim, dim * 2)),
                    ])
                )

    def forward(self, x, cond=None):
        # We stored them as [attn, ff], but let's carefully apply them:
        for block in self.layers:
            attn, ff = block
            x = attn(x, cond=cond) + x
            x = ff(x) + x
        return x



class UNet1d(nn.Module):
    """
    A 1D U-Net model for diffusion-based tasks. 
    (With some custom attention, RMSNorm, etc.)
    """
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        dropout=0.0,
        conditional=True,
        init_cond_channels=None,
        attn_cond_channels=None,
        attn_cond_init_dim=None,
        learned_variance=False,
        sinusoidal_pos_emb_theta=10000,
        attn_heads=4,
        attn_dim_head=32,
        tfer_dim_mult=620,
        tfer_depth=4,
        downsample_dim=40000,
        simple=True,
        pos_output_only=False,
    ):
        super().__init__()
        self.channels = channels
        self.conditional = conditional
        input_channels = channels + default(init_cond_channels, 0)
        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv1d(input_channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time embeddings
        time_dim = dim * 4
        sinu_pos_emb = SinusoidalPosEmb(dim, theta=sinusoidal_pos_emb_theta)
        self.time_mlp = nn.Sequential(
            sinu_pos_emb, 
            nn.Linear(dim, time_dim), 
            nn.GELU(), 
            nn.Linear(time_dim, time_dim)
        )

        resnet_block = partial(ResnetBlock, time_emb_dim=time_dim, dropout=dropout)

        # conditioning signals
        if self.conditional:
            self.init_cond_proj = ConditionalScaleShift(
                time_emb_dim=time_dim, 
                dim=default(init_cond_channels, 1)
            )
            attn_cond_init_dim = default(attn_cond_init_dim, dim * 2)
            if simple:
                self.attn_cond_proj = nn.ModuleList([
                    nn.Identity(),
                    nn.Sequential(
                        nn.Conv1d(attn_cond_channels, attn_cond_init_dim, 7, padding=3),
                        nn.GELU(),
                        nn.Conv1d(attn_cond_init_dim, attn_cond_init_dim, 1),
                    ),
                ])
            else:
                self.attn_cond_proj = nn.ModuleList([
                    nn.Sequential(
                        nn.Conv1d(attn_cond_channels, attn_cond_init_dim, 7, padding=3),
                        resnet_block(attn_cond_init_dim, attn_cond_init_dim),
                        resnet_block(attn_cond_init_dim, attn_cond_init_dim),
                        Residual(
                            PreNorm(attn_cond_init_dim, LinearAttention(attn_cond_init_dim))
                        ),
                    ),
                    Transformer1d(
                        attn_cond_init_dim * tfer_dim_mult,
                        depth=tfer_depth // 2,
                        heads=attn_heads,
                        dim_head=attn_dim_head,
                    ),
                ])

        # layers (Downsampling)
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            self.downs.append(
                nn.ModuleList([
                    resnet_block(dim_in, dim_in),
                    resnet_block(dim_in, dim_in),
                    Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                    (Downsample(dim_in, dim_out) if not is_last 
                     else nn.Conv1d(dim_in, dim_out, 3, padding=1)),
                ])
            )

        # mid
        self.downsampled_n = downsample_dim // (2 ** (len(dim_mults) - 1))
        mid_dim = dims[-1]
        self.mid_block1 = resnet_block(mid_dim * self.downsampled_n, mid_dim * self.downsampled_n)
        if simple:
            self.mid_attn = Residual(
                PreNorm(
                    mid_dim * self.downsampled_n,
                    Attention(
                        mid_dim * self.downsampled_n,
                        heads=attn_heads,
                        dim_head=attn_dim_head,
                        use_xattn=self.conditional,
                        cond_dim=attn_cond_init_dim,
                    ),
                )
            )
        else:
            self.mid_attn = Residual(
                PreNorm(
                    mid_dim * self.downsampled_n,
                    Transformer1d(
                        mid_dim * self.downsampled_n,
                        depth=tfer_depth,
                        heads=attn_heads,
                        dim_head=attn_dim_head,
                        use_xattn=self.conditional,
                        cond_dim=attn_cond_init_dim,
                    ),
                )
            )
        self.mid_block2 = resnet_block(mid_dim * self.downsampled_n, mid_dim * self.downsampled_n)

        # Upsampling
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)
            self.ups.append(
                nn.ModuleList([
                    resnet_block(dim_out + dim_in, dim_out),
                    resnet_block(dim_out + dim_in, dim_out),
                    Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                    (Upsample(dim_out, dim_in) if not is_last 
                     else nn.Conv1d(dim_out, dim_in, 3, padding=1)),
                ])
            )

        # final
        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)
        self.final_res_block = resnet_block(init_dim * 2, init_dim)
        self.final_conv = nn.Conv1d(init_dim, self.out_dim, 1)
        self.final_act = nn.Softplus() if pos_output_only else nn.Identity()

    def forward(self, x, time, init_cond=None, attn_cond=None):
        b = x.shape[0] if x.dim() == 3 else 1
        if x.dim() == 3:
            # shape: b, rt, mz -> (b rt), 1, mz
            x = rearrange(x, "b rt mz -> (b rt) 1 mz")
        else:
            # shape: rt, mz -> rt, 1, mz
            x = rearrange(x, "rt mz -> rt 1 mz")

        t = self.time_mlp(time)

        if self.conditional:
            init_cond = default(init_cond, lambda: torch.zeros_like(x))
            if init_cond.dim() == 3:
                init_cond = rearrange(init_cond, "b rt mz -> (b rt) 1 mz")
            else:
                init_cond = rearrange(init_cond, "rt mz -> rt 1 mz")
            init_cond = self.init_cond_proj(init_cond, t)
            x = torch.cat((init_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        if self.conditional:
            attn_cond = default(attn_cond, lambda: torch.zeros_like(x))
            # Example shape handling for cond
            if attn_cond.dim() == 2:
                attn_cond = rearrange(attn_cond, "b rt -> (b rt) 1 1")
            else:
                attn_cond = rearrange(attn_cond, "b rt mz -> (b rt) 1 mz")

            mz_net, rt_net = self.attn_cond_proj
            attn_cond = mz_net(attn_cond)
            attn_cond = rearrange(attn_cond, "(b rt) d mz -> b (d mz) rt", b=b)
            attn_cond = rt_net(attn_cond)

        # Down pass
        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # mid
        x = rearrange(x, "(b rt) d mz -> b (d mz) rt", b=b)
        x = self.mid_block1(x, t)
        x = self.mid_attn(x, cond=attn_cond if self.conditional else None)
        x = self.mid_block2(x, t)
        x = rearrange(x, "b (d mz) rt -> (b rt) d mz", mz=self.downsampled_n)

        # Up pass
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        x = torch.cat((x, r), dim=1)
        x = self.final_res_block(x, t)
        x = self.final_conv(x)
        x = rearrange(x, "(b rt) d mz -> b (rt d) mz", b=b)
        return self.final_act(x)

#      WRAPPER FOR QUANTIZATION (QuantStub/DeQuantStub)
# only gonna quantise supported layers so doesnt matter 
class QuantizedUNet1dWrapper(nn.Module):
    """
    Wrap the UNet1d so that we can do static quantization (Eager Mode).
    This inserts:
      - QuantStub at input
      - DeQuantStub at output
    Then we let PyTorchâ€™s built-in prepare() and convert() do the rest.
    """
    def __init__(self, unet: nn.Module):
        super().__init__()
        self.quant = tq.QuantStub()
        self.model = unet
        self.dequant = tq.DeQuantStub()

    def forward(self, x, time, init_cond=None, attn_cond=None):
        # Quantize the input
        x = self.quant(x)
        # Forward pass through your original U-Net
        out = self.model(x, time, init_cond, attn_cond)
        # Dequantize the output
        out = self.dequant(out)
        return out





In [None]:
    # 1) Build original UNet
unet = UNet1d(
        **model_init
        # ...
    ).to('cuda')

model = QuantizedUNet1dWrapper(unet)
import torch
import torch.ao.quantization as tq

# Force per-tensor affine for both activation & weight
per_tensor_qconfig = tq.QConfig(
    activation=tq.HistogramObserver.with_args(
        qscheme=torch.per_tensor_affine,
        reduce_range=False
    ),
    weight=tq.HistogramObserver.with_args(
        qscheme=torch.per_tensor_affine,
        dtype=torch.qint8,
        reduce_range=False
    ),
)

model.qconfig = per_tensor_qconfig

    # 3) Set QConfig (here, "fbgemm" for x86 CPU with AVX2)
# model.qconfig = tq.get_default_qconfig("fbgemm")



In [25]:
tq.prepare(model, inplace=True)

    # 5) Calibrate
    # You need some representative data to run through the model in eval mode.
model.eval()

QuantizedUNet1dWrapper(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (model): UNet1d(
    (init_conv): Conv1d(
      2, 4, kernel_size=(7,), stride=(1,), padding=(3,)
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (time_mlp): Sequential(
      (0): SinusoidalPosEmb()
      (1): Linear(
        in_features=4, out_features=16, bias=True
        (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
      )
      (2): GELU(approximate='none')
      (3): Linear(
        in_features=16, out_features=16, bias=True
        (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
      )
    )
    (init_cond_proj): ConditionalScaleShift(
      (to_scale_shift): Sequential(
        (0): SiLU()
        (1): Linear(
          in_features=16, out_features=2, bias=True
          (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
        )
      )
 

In [None]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Parameter: {name}, Value: {param.data}")

Parameter: model.init_conv.weight, Value: tensor([[[-0.1598,  0.2116, -0.2439,  0.0442,  0.0973, -0.1735,  0.1961],
         [-0.0621,  0.1701, -0.0554,  0.0978,  0.0060, -0.0044,  0.1072]],

        [[ 0.2516,  0.1461, -0.2599,  0.2404, -0.2043,  0.2586, -0.2442],
         [ 0.2524,  0.0733, -0.0006,  0.1177, -0.0262,  0.0623, -0.0380]],

        [[-0.0697,  0.0896,  0.1742,  0.2116, -0.0023, -0.2453, -0.0146],
         [-0.1664, -0.1529,  0.1832, -0.0101,  0.2246, -0.1536, -0.2020]],

        [[ 0.1580,  0.0400, -0.2502,  0.2476, -0.1447, -0.1175, -0.1621],
         [ 0.0909,  0.0851, -0.2309,  0.0562,  0.0414,  0.0204, -0.2634]]],
       device='cuda:0')
Parameter: model.init_conv.bias, Value: tensor([-0.1324,  0.1871,  0.2254, -0.1771], device='cuda:0')
Parameter: model.time_mlp.1.weight, Value: tensor([[ 0.0625, -0.2258, -0.1206, -0.1757],
        [ 0.4835, -0.2103, -0.2763, -0.2715],
        [-0.4072,  0.2352, -0.1956, -0.4070],
        [-0.1788,  0.3326, -0.2065, -0.3320],
     

In [None]:
from torchinfo import summary
summary(model)#its the quantised model

Layer (type:depth-idx)                                       Param #
QuantizedUNet1dWrapper                                       --
â”œâ”€Quantize: 1-1                                              --
â”œâ”€UNet1d: 1-2                                                --
â”‚    â””â”€Conv1d: 2-1                                           --
â”‚    â””â”€Sequential: 2-2                                       --
â”‚    â”‚    â””â”€SinusoidalPosEmb: 3-1                            --
â”‚    â”‚    â””â”€Linear: 3-2                                      --
â”‚    â”‚    â””â”€GELU: 3-3                                        --
â”‚    â”‚    â””â”€Linear: 3-4                                      --
â”‚    â””â”€ConditionalScaleShift: 2-3                            --
â”‚    â”‚    â””â”€Sequential: 3-5                                  --
â”‚    â””â”€ModuleList: 2-4                                       --
â”‚    â”‚    â””â”€Identity: 3-6                                    --
â”‚    â”‚    â””â”

In [35]:
m = UNet1d(**model_init)
summary(m)


Layer (type:depth-idx)                             Param #
UNet1d                                             --
â”œâ”€Conv1d: 1-1                                      60
â”œâ”€Sequential: 1-2                                  --
â”‚    â””â”€SinusoidalPosEmb: 2-1                       --
â”‚    â””â”€Linear: 2-2                                 80
â”‚    â””â”€GELU: 2-3                                   --
â”‚    â””â”€Linear: 2-4                                 272
â”œâ”€ConditionalScaleShift: 1-3                       --
â”‚    â””â”€Sequential: 2-5                             --
â”‚    â”‚    â””â”€SiLU: 3-1                              --
â”‚    â”‚    â””â”€Linear: 3-2                            34
â”œâ”€ModuleList: 1-4                                  --
â”‚    â””â”€Identity: 2-6                               --
â”‚    â””â”€Sequential: 2-7                             --
â”‚    â”‚    â””â”€Conv1d: 3-3                            64
â”‚    â”‚    â””â”€GELU: 3-4                   

In [None]:
import torch
import numpy as np

import torch.ao.quantization as tq
torch.serialization.add_safe_globals([np.dtype])
from torch.serialization import add_safe_globals

add_safe_globals([np.core.multiarray.scalar])




In [12]:
path='dquartic/model/best_model.pth'
checkpoint = torch.load(path, map_location="cuda",weights_only=False)



In [26]:
if "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
    # Load the state dict into the model
    model.load_state_dict(state_dict,strict=False)
else:
    print("Error: Checkpoint doesn't contain 'model_state_dict'")

# # 3) Assign a quantization config
# model.qconfig = tq.get_default_qconfig("fbgemm")


# tq.prepare(model, inplace=True)

In [27]:
# 5) Calibrate with a small sample of data from your dataloader
torch.manual_seed(42)
device='cuda'
# model_prepared.eval()
with torch.no_grad():
    calibration_samples = 0
    for i, (ms2_1, ms1_1, ms2_2, ms1_2) in enumerate(data_loader):
        if i >= 2:  # Limit to a small number of batches for calibration
            break
            
        # Move tensors to the right device
        ms2_1 = ms2_1.to(device)
        ms1_1 = ms1_1.to(device)
        ms2_2 = ms2_2.to(device)
        
        # Create mixture conditioning (as done in training)
        mixture_weights = (0.5, 0.5)
        ms2_cond = (ms2_1 * mixture_weights[0]) + (ms2_2 * mixture_weights[1])
        
        # Create random timesteps for batch
        batch_size = ms2_1.shape[0]
        timestep = torch.randint(0, 1000, (batch_size,), device=device)
        
        # Add noise to ms2_1 to simulate diffusion process
        noise = torch.randn_like(ms2_1, device=device)
        alpha = 0.7  # Using a fixed alpha for simplicity
        noisy_ms2 = math.sqrt(alpha) * ms2_1 + math.sqrt(1-alpha) * noise
        
        # Pass through model for calibration
        _ = model(noisy_ms2, timestep, ms2_cond, ms1_1)
        
        calibration_samples += len(ms2_1)
        print(f"Calibrated with {calibration_samples} samples so far")

# 6) Convert the calibrated model to quantized form


Calibrated with 1 samples so far
Calibrated with 2 samples so far


In [28]:

model_int8 = tq.convert(model, inplace=False)
    
    # 7) Save the quantized weights with all checkpoint components
quantized_checkpoint = {
        'epoch': checkpoint['epoch'],
        'model_state_dict': model_int8.state_dict(),
        'optimizer_state_dict': checkpoint['optimizer_state_dict'],
        'scheduler_state_dict': checkpoint['scheduler_state_dict'],
        'best_loss': checkpoint['best_loss']
    }
torch.save(quantized_checkpoint, "my_unet_checkpoint_int8.pth")
print("Successfully saved quantized model")


Successfully saved quantized model
