In [1]:
import torch
import torch.nn as nn

import torch.nn.functional as F
import math
from collections import defaultdict

from timm.models.layers import trunc_normal_
from timm.models.layers import DropPath
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import _load_weights

import torch
import torch.nn as nn
from einops import rearrange
from pathlib import Path

import torch.nn.functional as F

from timm.models.layers import DropPath


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> dropout -> Linear -> dropout.

    Args:
        dim: width of the input features.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability; the same dropout module is applied
            after the activation and after the second linear layer.
        out_dim: width of the output features; ``dim`` is used when None.
    """

    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
        super().__init__()
        target_dim = dim if out_dim is None else out_dim
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, target_dim)
        self.drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        """Return this module itself."""
        return self

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: width of the token embeddings; must be divisible by ``heads``.
        heads: number of attention heads.
        dropout: dropout probability for the attention weights and for the
            output projection.

    Raises:
        ValueError: if ``dim`` is not a multiple of ``heads``.
    """

    def __init__(self, dim, heads, dropout):
        super().__init__()
        # Without this check the mismatch only surfaces later, as an opaque
        # reshape error inside forward().
        if dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by heads ({heads})"
            )
        self.heads = heads
        head_dim = dim // heads
        self.scale = head_dim ** -0.5
        self.attn = None

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        """Return this module itself."""
        return self

    def forward(self, x, mask=None):
        """Attend over the token axis of ``x``.

        Args:
            x: tensor of shape (B, N, C).
            mask: accepted for interface compatibility; it is not used by
                this implementation.

        Returns:
            A pair ``(out, attn)`` where ``out`` has shape (B, N, C) and
            ``attn`` holds the post-dropout attention weights with shape
            (B, heads, N, N).
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, C // heads)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.heads, C // self.heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel axis.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, attn


class Block(nn.Module):
    """Pre-norm transformer layer: self-attention, then an MLP, each residual.

    Args:
        dim: width of the token embeddings.
        heads: number of attention heads.
        mlp_dim: hidden width of the MLP.
        dropout: dropout probability passed to the attention and the MLP.
        drop_path: stochastic-depth rate; when it is not positive the
            residual branches are passed through unchanged.
    """

    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x, mask=None, return_attention=False):
        """Run the layer on ``x``.

        When ``return_attention`` is true, the attention weights of this
        layer are returned instead of the transformed tokens.
        """
        attn_out, attn_weights = self.attn(self.norm1(x), mask)
        if return_attention:
            return attn_weights
        after_attn = x + self.drop_path(attn_out)
        return after_attn + self.drop_path(self.mlp(self.norm2(after_attn)))

def init_weights(m):
    """Initialise one module in place; meant to be used with ``Module.apply``.

    ``nn.Linear`` weights get a truncated normal (std 0.02) and a zero bias
    when a bias exists. ``nn.LayerNorm`` gets a zero bias and unit weight.
    Any other module is left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        return
    if isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    """Bilinearly resize the grid part of a position embedding.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: tensor of shape (1, num_extra_tokens + old_h * old_w, D).
        grid_old_shape: (old_h, old_w) of the stored grid, or None to assume
            a square grid whose side is the integer square root of the
            number of grid tokens.
        grid_new_shape: (new_h, new_w) of the target grid.
        num_extra_tokens: number of leading tokens (class / distillation)
            that are copied through without resizing.

    Returns:
        Tensor of shape (1, num_extra_tokens + new_h * new_w, D).
    """
    extra_tokens = posemb[:, :num_extra_tokens]
    grid_tokens = posemb[0, num_extra_tokens:]

    if grid_old_shape is not None:
        old_h, old_w = grid_old_shape
    else:
        old_h = int(math.sqrt(len(grid_tokens)))
        old_w = old_h

    new_h, new_w = grid_new_shape
    # Tokens -> (1, D, old_h, old_w) so that interpolate treats D as channels.
    grid = grid_tokens.reshape(1, old_h, old_w, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bilinear")
    # Back to a flat token sequence.
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, -1)
    return torch.cat([extra_tokens, grid], dim=1)

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    Args:
        image_size: (height, width) of the expected input.
        patch_size: side length of a patch; must divide both image sides.
        embed_dim: width of each patch embedding.
        channels: number of input image channels.

    Raises:
        ValueError: if either image side is not a multiple of ``patch_size``.
    """

    def __init__(self, image_size, patch_size, embed_dim, channels):
        super().__init__()

        self.image_size = image_size
        if any(side % patch_size != 0 for side in image_size[:2]):
            raise ValueError("image dimensions must be divisible by the patch size")
        rows = image_size[0] // patch_size
        cols = image_size[1] // patch_size
        self.grid_size = rows, cols
        self.num_patches = rows * cols
        self.patch_size = patch_size
        # A strided convolution embeds every patch in a single pass.
        self.proj = nn.Conv2d(
            channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, im):
        """Map a (B, C, H, W) image to (B, num_patches, embed_dim) tokens."""
        _batch, _channels, _height, _width = im.shape  # rejects non-4D input
        patches = self.proj(im)
        return patches.flatten(2).transpose(1, 2)


def unpadding(y, target_size):
    """Crop trailing rows/columns of ``y`` down to ``target_size``.

    Args:
        y: tensor whose dimensions 2 and 3 are height and width.
        target_size: (H, W) to crop to.

    Returns:
        ``y`` with any surplus rows removed from the bottom and any surplus
        columns removed from the right. A dimension that is already at or
        below the target is left unchanged.
    """
    target_h, target_w = target_size
    surplus_h = y.size(2) - target_h
    surplus_w = y.size(3) - target_w
    # The surplus comes from padding, so it sits at the end of each axis.
    if surplus_h > 0:
        y = y[:, :, :-surplus_h]
    if surplus_w > 0:
        y = y[:, :, :, :-surplus_w]
    return y


class VisionTransformer(nn.Module):
    """Vision Transformer encoder that returns one embedding per token.

    The image is split into patches by ``PatchEmbedding``, a class token (and
    a distillation token when ``distilled`` is true) is prepended, position
    embeddings are added, and the sequence is run through ``n_layers``
    ``Block`` modules.

    ``norm``, ``head``, ``head_dist``, ``pre_logits`` and ``decoder`` are
    created but not used by ``forward``.

    Args:
        image_size: (height, width) the position embeddings are sized for.
        patch_size: side length of the square patches.
        n_layers: number of transformer blocks.
        d_model: token embedding width.
        d_ff: hidden width of each block's MLP.
        n_heads: number of attention heads per block.
        n_cls: number of output classes of the linear heads and the decoder.
        dropout: dropout probability for the embeddings and inside blocks.
        drop_path_rate: largest stochastic-depth rate; the per-block rate
            grows linearly from 0 at the first block to this value.
        distilled: when true, add a distillation token and head.
        channels: number of input image channels.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        n_layers,
        d_model,
        d_ff,
        n_heads,
        n_cls,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            image_size,
            patch_size,
            d_model,
            channels,
        )
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.n_cls = n_cls

        # cls and pos tokens
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.distilled = distilled
        if self.distilled:
            self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.patch_embed.num_patches + 2, d_model)
            )
            self.head_dist = nn.Linear(d_model, n_cls)
        else:
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.patch_embed.num_patches + 1, d_model)
            )

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
        )

        # output head
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_cls)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        if self.distilled:
            trunc_normal_(self.dist_token, std=0.02)
        self.pre_logits = nn.Identity()
        # Size the decoder from this model's own configuration rather than
        # from fixed constants, so it matches any n_cls / patch_size / d_model.
        self.decoder = DecoderLinear(n_cls, patch_size, d_model)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {"pos_embed", "cls_token", "dist_token"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        """Load weights from ``checkpoint_path`` via timm's ``_load_weights``."""
        _load_weights(self, checkpoint_path, prefix)

    def _embed_tokens(self, im):
        """Patchify ``im``, prepend the extra tokens and add position embeddings.

        The position embedding is resized on the fly when the number of
        tokens differs from the number it was created for.
        """
        B, _, H, W = im.shape
        PS = self.patch_size

        x = self.patch_embed(im)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.distilled:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            x = torch.cat((cls_tokens, x), dim=1)

        pos_embed = self.pos_embed
        num_extra_tokens = 1 + self.distilled
        if x.shape[1] != pos_embed.shape[1]:
            pos_embed = resize_pos_embed(
                pos_embed,
                self.patch_embed.grid_size,
                (H // PS, W // PS),
                num_extra_tokens,
            )
        return x + pos_embed

    def forward(self, im, return_features=False):
        """Encode a batch of images.

        Args:
            im: tensor of shape (B, C, H, W).
            return_features: accepted for interface compatibility; unused.

        Returns:
            Token embeddings of shape (B, num_extra_tokens + num_patches,
            d_model), including the class (and distillation) token.
        """
        x = self._embed_tokens(im)
        x = self.dropout(x)

        for blk in self.blocks:
            x = blk(x)

        # NOTE: the final LayerNorm (self.norm) and the decoder path
        # (strip class token -> self.decoder -> interpolate -> unpadding)
        # are not applied here; the raw block output is returned.
        return x

    def get_attention_map(self, im, layer_id):
        """Return the attention weights of block ``layer_id`` for ``im``.

        Embedding dropout is not applied on this path.

        Raises:
            ValueError: if ``layer_id`` is outside ``[0, n_layers)``.
        """
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(
                f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
            )
        x = self._embed_tokens(im)

        # Run the blocks before the requested one, then ask that block for
        # its attention weights instead of its output.
        for blk in self.blocks[:layer_id]:
            x = blk(x)
        return self.blocks[layer_id](x, return_attention=True)

            
class DecoderLinear(nn.Module):
    """Per-patch linear classifier that reshapes its output into a 2-D map.

    Args:
        n_cls: number of output classes.
        patch_size: side length of the patches the tokens were built from.
        d_encoder: width of the incoming token embeddings.
    """

    def __init__(self, n_cls, patch_size, d_encoder):
        super().__init__()

        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_cls = n_cls

        self.head = nn.Linear(self.d_encoder, n_cls)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """No parameter of this module is excluded from weight decay."""
        return set()

    def forward(self, x, im_size):
        """Classify each patch token and lay the logits out on the patch grid.

        Args:
            x: tokens of shape (b, h * w, d_encoder).
            im_size: (H, W) of the image; only H is used, the grid width is
                inferred from the number of tokens.

        Returns:
            Logits of shape (b, n_cls, h, w) with ``h = H // patch_size``.
        """
        height, _width = im_size
        grid_rows = height // self.patch_size
        logits = self.head(x)
        return rearrange(logits, "b (h w) c -> b c h w", h=grid_rows)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Vit(nn.Module):
    """Apply one shared ``VisionTransformer`` encoder to every frame of a clip.

    Args:
        image_size, patch_size, n_layers, d_model, n_heads: passed through to
            the encoder.
        d_ff, n_cls: see the NOTE in ``__init__`` about the positions these
            two take in the encoder call.
        dropout, drop_path_rate, distilled, channels: forwarded to the
            encoder by keyword.
    """

    def __init__(self,image_size,patch_size,n_layers,d_model,d_ff,
        n_heads,
        n_cls,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3):
        # Initialise nn.Module before assigning any attribute.
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.n_layers =  n_layers
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_cls = n_cls
        self.d_ff = d_ff

        # NOTE(review): the positional order below hands this class's
        # ``n_cls`` to the encoder's ``d_ff`` slot and ``d_ff`` to its
        # ``n_cls`` slot. ViViT calls Vit with the same two arguments
        # swapped, so the two swaps cancel and the encoder ends up
        # configured correctly in that pipeline. The order is kept so that
        # pipeline keeps working; fix both call sites together.
        self.vit_encoder = VisionTransformer(
            image_size,
            patch_size,
            n_layers,
            d_model,
            n_cls,
            n_heads,
            d_ff,
            dropout=dropout,
            drop_path_rate=drop_path_rate,
            distilled=distilled,
            channels=channels,
        )

    def forward(self,images):
        """Encode each frame independently.

        Args:
            images: clip tensor of shape (b, t, c, h, w).

        Returns:
            Tensor with the frame axis first: one encoder output of shape
            (b, tokens, d_model) per frame, stacked into (t, b, tokens,
            d_model).
        """
        frames = rearrange(images, "b t c h w -> t b c h w")
        # Each ``frame`` is a (b, c, h, w) batch for one time step.
        return torch.stack([self.vit_encoder(frame) for frame in frames])

            
class Predictor(nn.Module):
    """Spatio-temporal transformer run over the frames seen so far.

    For every frame index ``i >= 1`` the encoder outputs of frames
    ``0 .. i-1`` are tagged with a learned per-frame temporal token,
    concatenated along the token axis and passed through the blocks.

    Args:
        n_layers: number of transformer blocks.
        d_model: token embedding width.
        d_ff: hidden width of each block's MLP.
        n_heads: number of attention heads per block.
        dropout: dropout probability inside the blocks.
        drop_path_rate: largest stochastic-depth rate; the per-block rate
            grows linearly from 0 to this value.
        distilled: unused.
        channels: unused.
        max_context: largest number of context frames that has a learned
            temporal token; clips may have at most ``max_context + 1`` frames.
    """

    def __init__(self,n_layers,d_model,d_ff,
        n_heads,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3,
        max_context=4):
        super().__init__()

        self.n_layers =  n_layers
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_ff = d_ff
        self.max_context = max_context

        # One learned token per context-frame position. forward() slices it
        # instead of re-creating a random parameter on every call, so the
        # tokens are actually trained and follow the module's device.
        self.temporal_token = nn.Parameter(torch.randn(1,max_context,1,d_model))

        # One rate per block, so any n_layers works.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)])

        self.norm = nn.LayerNorm(d_model)

    def forward(self,images, x):
        """Produce one predictor embedding for every frame after the first.

        Args:
            images: clip tensor of shape (b, t, c, h, w); only its shape is
                used.
            x: per-frame encoder outputs, indexable by frame, each of shape
                (b, n, d).

        Returns:
            A list of ``t - 1`` tensors of shape (b, 1, n, d). Entry
            ``i - 1`` is computed from frames ``0 .. i-1`` and keeps only
            the first time slot of the transformer output.

        Raises:
            ValueError: if the clip needs more context frames than
                ``max_context``.
        """
        _, num_seqs, _, _, _ = images.shape
        if num_seqs - 1 > self.max_context:
            raise ValueError(
                f"clip has {num_seqs} frames but the predictor only has temporal "
                f"tokens for {self.max_context} context frames"
            )

        vit_encoderoutput=[]
        predictor_embs=[]

        for i in range(num_seqs):
            print("------------------------------------------------------------------------")
            print("For image", i,"in sequence")
            if i > 0:
                # Context = encoder outputs of frames 0 .. i-1, as (b, i, n, d).
                vit_embds = rearrange(
                    torch.stack(vit_encoderoutput), "t b n d -> b t n d"
                )
                temporal_token = self.temporal_token[:, :i]

                print("Temporal token",temporal_token.shape)

                vit_embds = temporal_token + vit_embds

                vit_embds = rearrange(vit_embds, "b t n d -> b (t n) d")
                print("Spatio-temporal Transformer input shape ->", vit_embds.shape)
                for blk in self.blocks:
                    vit_embds = blk(vit_embds)

                print("Spatial-transformer output embeddinngs",vit_embds.shape)

                spatemp_embds = rearrange(vit_embds , "b (t n) d  -> b t n d",t=i)

                # Keep only the first time slot of the output.
                predictor_embs.append(spatemp_embds[:,:1])

            vit_encoderoutput.append(x[i])

        return predictor_embs
    
    
    
    
    
    
class Corrector(nn.Module):
    """Transformer that refines each prediction using the observed frame.

    For step ``i`` the predictor embedding is concatenated with the encoder
    output of frame ``i + 1`` along the token axis, passed through the
    blocks, and the first half of the output is kept and layer-normalised.

    Args:
        n_layers: number of transformer blocks.
        d_model: token embedding width.
        d_ff: hidden width of each block's MLP.
        n_heads: number of attention heads per block.
        dropout: dropout probability inside the blocks.
        drop_path_rate: largest stochastic-depth rate; the per-block rate
            grows linearly from 0 to this value.
        distilled: unused.
        channels: unused.
    """

    def __init__(self,n_layers,d_model,d_ff,
        n_heads,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3):
        super().__init__()

        self.n_layers =  n_layers
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_ff = d_ff

        # One rate per block, so any n_layers works.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)])

        self.norm = nn.LayerNorm(d_model)

    def forward(self,images, x, predictor_embds):
        """Correct each predictor embedding with the next frame's encoding.

        Args:
            images: clip tensor of shape (b, t, c, h, w); only its shape is
                used.
            x: per-frame encoder outputs, indexable by frame, each of shape
                (b, n, d).
            predictor_embds: sequence of ``t - 1`` tensors of shape
                (b, 1, n, d). It is read only; the caller's list is not
                modified.

        Returns:
            A list of ``t - 1`` tensors of shape (b, n, d).
        """
        _, num_seqs, _, _, _ = images.shape

        corrector_embds = []

        for i in range(num_seqs-1):
            print("------------------------------------------------------------------------")
            print("For image", i,"in sequence")

            # Flatten into a local variable instead of writing back into the
            # caller's list, which would make a second call fail.
            predicted = rearrange(predictor_embds[i], "b t n d -> b (t n) d")

            # Concatenate with the encoder output of the next time step.
            corrector_input = torch.cat([predicted , x[i+1]], dim=1)

            print("Corrector transformer Input -> ", corrector_input.shape )

            for blk in self.blocks:
                corrector_input = blk(corrector_input)

            print("After passing through corrector transformer", corrector_input.shape)

            # Split the sequence back into its two halves and keep the first.
            corrector_output = rearrange(corrector_input, " b (t n) d -> b t n d", t=2)
            corrector_output = corrector_output[:,:1]

            print("After discarding pne embedding from corrector transformer", corrector_output.shape)

            # Fold the singleton time axis into the batch axis.
            corrector_output = rearrange(corrector_output,"b t n d -> (b t) n d")

            corrector_output = self.norm(corrector_output)

            corrector_embds.append(corrector_output)

        return corrector_embds

                

                
class Decoder(nn.Module):
    """Turn per-frame token embeddings into full-resolution class maps.

    Args:
        image_size: (height, width) of the output maps.
        patch_size: side length of the patches the tokens were built from.
        d_model: token embedding width.
        n_cls: number of output classes.
        distilled: unused.
        channels: unused.
    """

    def __init__(self,image_size,patch_size,d_model,
        n_cls,
        distilled=False,
        channels=3):
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.d_model = d_model
        self.n_cls = n_cls

        self.decoder = DecoderLinear(n_cls,patch_size,d_model)

    def forward(self, vit_embs, corrector_embds):
        """Decode every frame of a clip.

        Frame 0 is decoded from its encoder embedding; frame ``i >= 1`` is
        decoded from ``corrector_embds[i - 1]``. The first token of each
        embedding is dropped before decoding.

        Returns:
            A list with one tensor per frame, bilinearly resized to
            ``image_size``.
        """
        h,w = self.image_size
        segmented_output=[]
        for i in range(len(vit_embs)):
            print("------------------------------------------------------------------------")
            print("For image", i,"in sequence")
            is_first_frame = i == 0
            if is_first_frame:
                tokens = vit_embs[i]
            else:
                tokens = corrector_embds[i-1]
                print(tokens.shape)

            # Strip the leading (class) token; the decoder wants patch tokens only.
            patch_tokens = tokens[:, 1:]
            logits = self.decoder(patch_tokens,(h,w))
            masks = F.interpolate(logits, size=self.image_size, mode="bilinear")
            segmented_output.append(masks)

            if not is_first_frame:
                print("Final decoder output", masks.shape)

        return segmented_output
                

                


class ViViT(nn.Module):
    """Video segmentation model: encoder -> predictor -> corrector -> decoder.

    Args:
        image_size: (height, width) of the input frames and output maps.
        patch_size: side length of the square patches.
        n_layers: number of transformer blocks in each transformer.
        d_model: token embedding width.
        d_ff: hidden width of the transformer MLPs.
        n_heads: number of attention heads.
        n_cls: number of segmentation classes.
        dropout: dropout probability forwarded to the sub-modules.
        drop_path_rate: stochastic-depth rate forwarded to the sub-modules.
        distilled: not forwarded; the decoder strips exactly one leading
            token, which matches the non-distilled encoder.
        channels: number of input image channels, forwarded to the encoder.
    """

    def __init__(self,image_size,patch_size,n_layers,d_model,d_ff,
        n_heads,
        n_cls,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3):
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.n_layers =  n_layers
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_cls = n_cls
        self.d_ff = d_ff

        # NOTE(review): n_cls and d_ff are handed to Vit in swapped
        # positions. Vit swaps the same two arguments again when it builds
        # its encoder, so the encoder is configured correctly. The order is
        # kept here on purpose; fix both call sites together.
        self.vit_embs = Vit(
            image_size, patch_size, n_layers, d_model, n_cls, n_heads, d_ff,
            dropout=dropout, drop_path_rate=drop_path_rate, channels=channels,
        )
        self.predictor_embs = Predictor(
            n_layers, d_model, d_ff, n_heads,
            dropout=dropout, drop_path_rate=drop_path_rate,
        )
        self.corrector_embs = Corrector(
            n_layers, d_model, d_ff, n_heads,
            dropout=dropout, drop_path_rate=drop_path_rate,
        )
        # Decoder's fourth parameter is n_cls. Passing d_ff there sized the
        # segmentation head for d_ff classes and pushed n_cls into
        # ``distilled``.
        self.decoder = Decoder(image_size, patch_size, d_model, n_cls)

    def forward(self,x):
        """Segment every frame of a clip.

        Args:
            x: clip tensor of shape (b, t, c, h, w).

        Returns:
            ``(output, vit_embds, predictor_embds, corrector_embds)``: the
            list of per-frame class maps, followed by the intermediate
            results of the encoder, predictor and corrector.
        """
        vit_embds = self.vit_embs(x)

        predictor_embds = self.predictor_embs(x, vit_embds)

        corrector_embds = self.corrector_embs(x, vit_embds, predictor_embds)

        output = self.decoder(vit_embds,corrector_embds)

        return output, vit_embds, predictor_embds, corrector_embds
        
        

In [3]:
# Model hyper-parameters shared by the cells below.
image_size = (768, 768)
patch_size = 16
num_images = 4
n_layers = 12
d_model = 192
d_ff = 4 * d_model  # MLP hidden width: four times the embedding width
n_heads = 3
n_cls = 19
#model = VideoTransformer(image_size, patch_size, num_images, n_layers, d_model, d_ff, n_heads, n_cls, d_ff)

In [4]:
# Smoke test: push one random clip of shape (1, 4, 3, 768, 768) through the
# full model. ``x4`` is the 4-tuple returned by ViViT.forward.
images = torch.rand((1, 4, 3, 768, 768))
model4 = ViViT(
    image_size, patch_size, n_layers, d_model, d_ff, n_heads, n_cls
)
x4 = model4(images)

------------------------------------------------------------------------
For image 0 in sequence
------------------------------------------------------------------------
For image 1 in sequence
Temporal token torch.Size([1, 1, 1, 192])
Spatio-temporal Transformer input shape -> torch.Size([1, 2305, 192])
Spatial-transformer output embeddinngs torch.Size([1, 2305, 192])
------------------------------------------------------------------------
For image 2 in sequence
Temporal token torch.Size([1, 2, 1, 192])
Spatio-temporal Transformer input shape -> torch.Size([1, 4610, 192])
Spatial-transformer output embeddinngs torch.Size([1, 4610, 192])
------------------------------------------------------------------------
For image 3 in sequence
Temporal token torch.Size([1, 3, 1, 192])
Spatio-temporal Transformer input shape -> torch.Size([1, 6915, 192])
Spatial-transformer output embeddinngs torch.Size([1, 6915, 192])
------------------------------------------------------------------------
For i

In [5]:
def train_one_epoch(
    vit, 
    corrector,
    model,
    data_loader,
    optimizer,
    lr_scheduler,
    epoch,
    amp_autocast,
    loss_scaler,
):
    """Train ``model`` for one epoch on ``data_loader``.

    The loss is the sum of an MSE term between the corrector embeddings and
    the encoder embeddings and a cross-entropy term between the predicted
    class maps and the ground truth.

    Args:
        vit: unused.
        corrector: unused.
        model: module whose call returns ``(output, vit_embds,
            predictor_embds, corrector_embds)``.
        data_loader: loader with ``set_epoch`` and ``__len__`` that yields
            dicts with ``"im"`` and ``"segmentation"`` entries.
        optimizer: optimizer to step.
        lr_scheduler: scheduler exposing ``step_update(num_updates=...)``.
        epoch: index of the current epoch.
        amp_autocast: context-manager factory wrapped around the forward pass.
        loss_scaler: optional callable ``(loss, optimizer, parameters=...)``
            that performs backward and the optimizer step; when None a plain
            backward and step are done here.

    Returns:
        The ``MetricLogger`` holding the loss and learning-rate meters.
    """
    import sys  # local import: only needed for the non-finite-loss report

    criterion1 = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_LABEL)
    # MSELoss has no ``ignore_index`` argument; passing one is a TypeError.
    criterion2 = torch.nn.MSELoss()
    logger = MetricLogger(delimiter="  ")
    header = f"Epoch: [{epoch}]"
    print_freq = 100

    model.train()
    data_loader.set_epoch(epoch)
    num_updates = epoch * len(data_loader)
    for batch in logger.log_every(data_loader, print_freq, header):
        im = batch["im"].to(ptu.device)
        seg_gt = batch["segmentation"].long().to(ptu.device)

        with amp_autocast():
            output, vit_embds, predictor_embds, corrector_embds = model(im)

            if isinstance(corrector_embds, (list, tuple)):
                # One (b, n, d) tensor per frame 1 .. t-1. Frame 0 has no
                # corrected embedding, so compare against the encoder output
                # of frames 1 .. t-1.
                corrector_embds = torch.stack(corrector_embds)
                vit_target = vit_embds[1:]
            else:
                vit_target = vit_embds
            loss1 = criterion2(corrector_embds, vit_target)

            if isinstance(output, (list, tuple)):
                # Per-frame (b, n_cls, H, W) maps -> (b * t, n_cls, H, W).
                output = torch.stack(output, dim=1).flatten(0, 1)
                # NOTE(review): assumes seg_gt is laid out as (b, t, H, W);
                # confirm against the dataset.
                seg_gt = seg_gt.flatten(0, 1)
            loss2 = criterion1(output,seg_gt)
            final_loss = loss1+loss2

        loss_value = final_loss.item()
        if not math.isfinite(loss_value):
            # Written straight to stdout so the report always appears and
            # does not depend on a patched ``print`` that accepts ``force``.
            sys.stdout.write("Loss is {}, stopping training\n".format(loss_value))
            sys.stdout.flush()
        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(
                final_loss,
                optimizer,
                parameters=model.parameters(),
            )
        else:
            final_loss.backward()
            optimizer.step()

        num_updates += 1
        lr_scheduler.step_update(num_updates=num_updates)

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        logger.update(
            loss=loss_value,
            learning_rate=optimizer.param_groups[0]["lr"],
        )

    return logger


In [9]:
#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import sys
from pathlib import Path
import yaml
import json
import numpy as np
import torch
import click
import os
import argparse
from torch.nn.parallel import DistributedDataParallel as DDP
from factory import load_model
from segm_video.utils import distributed
import segm_video.utils.torch as ptu
from segm_video import config


import torch.nn as nn


from factory import create_segmenter
from segm_video.optim.factory import create_optimizer, create_scheduler
from segm_video.data.factory import create_dataset
from segm_video.model.utils import num_params

from timm.utils import NativeScaler
from contextlib import suppress

from segm_video.utils.distributed import sync_model
from segm_video.engine import train_one_epoch, evaluate

import wandb
%env DATASET = "/home/user/siddiquia0/dataset"

# Run settings. A value of None (or a falsy value) means "use the default
# from the dataset configuration loaded below".
weight_decay = 0.0
scheduler = "polynomial"
optimizer = "sgd"
resume = False
eval_freq=None
epochs=None
normalization = None
amp = False
learning_rate =None
log_dir = "vivit"
dataset = "synpick"
decoder = "linear" 
backbone = "vit_tiny_patch16_384"
batch_size = None
# start distributed mode
ptu.set_gpu_mode(True)
#distributed.init_process()

# set up configuration
cfg = config.load_config()
model_cfg = cfg["model"][backbone]
dataset_cfg = cfg["dataset"][dataset]
# Every decoder name containing "mask_transformer" shares one config entry.
if "mask_transformer" in decoder:
    decoder_cfg = cfg["decoder"]["mask_transformer"]
else:
    decoder_cfg = cfg["decoder"][decoder]

# model config
# crop/window sizes fall back to the dataset image size when not configured.
im_size = dataset_cfg["im_size"]
crop_size = dataset_cfg.get("crop_size", im_size)
window_size = dataset_cfg.get("window_size", im_size)
window_stride = dataset_cfg.get("window_stride", im_size)

model_cfg["image_size"] = (crop_size, crop_size)
model_cfg["backbone"] = backbone
model_cfg["dropout"] = 0.0
model_cfg["drop_path_rate"] = 0.1
decoder_cfg["name"] = decoder
model_cfg["decoder"] = decoder_cfg

# dataset config
# Values set at the top of this cell override the dataset defaults.
world_batch_size = dataset_cfg["batch_size"]
num_epochs = dataset_cfg["epochs"]
lr = dataset_cfg["learning_rate"]
if batch_size:
    world_batch_size = batch_size
if epochs:
    num_epochs = epochs
if learning_rate:
    lr = learning_rate
if eval_freq is None:
    eval_freq = dataset_cfg.get("eval_freq", 1)

if normalization:
    model_cfg["normalization"] = normalization

# experiment config
# Per-process batch size: the global batch size split across all processes.
batch_size = world_batch_size // ptu.world_size
# ``variant`` gathers the whole run configuration in one dict; it is later
# dumped to variant.yml and its sub-dicts feed the dataset, the optimizer
# and the training loop.
variant = dict(
    world_batch_size=world_batch_size,
    version="normal",
    resume=resume,
    dataset_kwargs=dict(
        dataset=dataset,
        image_size=im_size,
        crop_size=crop_size,
        batch_size=batch_size,
        normalization=model_cfg["normalization"],
        split="train",
        num_workers=10,
    ),
    algorithm_kwargs=dict(
        batch_size=batch_size,
        start_epoch=0,
        num_epochs=num_epochs,
        eval_freq=eval_freq,
    ),
    optimizer_kwargs=dict(
        opt=optimizer,
        lr=lr,
        weight_decay=weight_decay,
        momentum=0.9,
        clip_grad=None,
        sched=scheduler,
        epochs=num_epochs,
        min_lr=1e-5,
        poly_power=0.9,
        poly_step_size=1,
    ),
    net_kwargs=model_cfg,
    amp=amp,
    log_dir=log_dir,
    inference_kwargs=dict(
        im_size=im_size,
        window_size=window_size,
        window_stride=window_stride,
    ),
)
#WANDB INITIALISATION
wandb.init(project="AIS-thesis", entity="aysha_athar")

# From here on ``log_dir`` is a Path; ``variant["log_dir"]`` keeps the
# original string.
log_dir = Path(log_dir)
log_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = log_dir / "checkpoint.pth"

# dataset
dataset_kwargs = variant["dataset_kwargs"]

train_loader = create_dataset(dataset_kwargs)
# The validation loader reuses the training settings, but reads the "val"
# split one sample at a time with cropping disabled.
val_kwargs = dataset_kwargs.copy()
val_kwargs["split"] = "val"
val_kwargs["batch_size"] = 1
val_kwargs["crop"] = False
val_loader = create_dataset(val_kwargs)
n_cls = train_loader.unwrapped.n_cls

# model
net_kwargs = variant["net_kwargs"]
net_kwargs["n_cls"] = n_cls

# NOTE(review): the model is the ``model4`` instance built in an earlier
# cell; ``net_kwargs`` is not used to construct it, so its n_cls and image
# size come from that cell rather than from the dataset configuration.
model = model4
model.to(ptu.device)

#model, _ = load_model("seg_ade20k_tiny/checkpoint.pth")


#model.decoder.mask_norm = nn.LayerNorm(n_cls)
#model.decoder.cls_emb = nn.Parameter(torch.randn(1, n_cls, 192))
#model.n_cls=19
#model.to('cuda')
print(model) 

# The string below is disabled code (a bare string expression, never run)
# that froze every parameter except three decoder parameters.
"""for name, param in model.named_parameters():
    param.requires_grad=False
    if(name=="decoder.mask_norm.weight" or name=="decoder.mask_norm.bias" or name=="decoder.cls_emb" ):
        param.requires_grad=True

for name,param in model.named_parameters():
    if(param.requires_grad==True):
        print(name)
        print("ok")"""

# optimizer
optimizer_kwargs = variant["optimizer_kwargs"]
optimizer_kwargs["iter_max"] = len(train_loader) * optimizer_kwargs["epochs"]
optimizer_kwargs["iter_warmup"] = 0.0
# Expose the optimizer settings as attributes of a Namespace, which is what
# create_scheduler is given below.
opt_args = argparse.Namespace()
opt_vars = vars(opt_args)
for k, v in optimizer_kwargs.items():
    opt_vars[k] = v
# NOTE(review): the optimizer is built directly with lr=0.001 and Nesterov
# momentum; the ``lr`` and ``weight_decay`` stored in optimizer_kwargs are
# only seen by the scheduler. This also rebinds ``optimizer``, which held
# the string "sgd" until now.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001,momentum=0.9,nesterov=True)
lr_scheduler = create_scheduler(opt_args, optimizer)
num_iterations = 0
# Without AMP, ``suppress()`` (called with no exception types) serves as a
# do-nothing context manager.
amp_autocast = suppress
loss_scaler = None
if amp:
    amp_autocast = torch.cuda.amp.autocast
    loss_scaler = NativeScaler()

# resume
if resume and checkpoint_path.exists():
    print(f"Resuming training from checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    if loss_scaler and "loss_scaler" in checkpoint:
        loss_scaler.load_state_dict(checkpoint["loss_scaler"])
    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
    # Continue with the epoch after the one stored in the checkpoint.
    variant["algorithm_kwargs"]["start_epoch"] = checkpoint["epoch"] + 1
else:
    sync_model(log_dir, model)

"""if ptu.distributed:
    model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True)"""

# save config
# The YAML text is rendered before the two assignments below, so the file
# reflects ``variant`` as it is at this point.
variant_str = yaml.dump(variant)
#print(f"Configuration:\n{variant_str}")
variant["net_kwargs"] = net_kwargs
variant["dataset_kwargs"] = dataset_kwargs
log_dir.mkdir(parents=True, exist_ok=True)
with open(log_dir / "variant.yml", "w") as f:
    f.write(variant_str)

# train
# start_epoch is read back from ``variant`` because the resume branch above
# may have changed it.
start_epoch = variant["algorithm_kwargs"]["start_epoch"]
num_epochs = variant["algorithm_kwargs"]["num_epochs"]
eval_freq = variant["algorithm_kwargs"]["eval_freq"]

# If the model is wrapped (it exposes ``.module``), checkpoints are taken
# from the inner module.
model_without_ddp = model
if hasattr(model, "module"):
    model_without_ddp = model.module

val_seg_gt = val_loader.dataset.get_gt_seg_maps()

print(f"Train dataset length: {len(train_loader.dataset)}")
print(f"Val dataset length: {len(val_loader.dataset)}")
#print(f"Encoder parameters: {num_params(model_without_ddp.encoder)}")
print(f"Decoder parameters: {num_params(model_without_ddp.decoder)}")
# Collect the original file names of the first five validation batches.
# The loader is iterated once: calling ``next(iter(val_loader))`` inside the
# loop would build a fresh iterator on every pass and keep re-reading the
# first batch. ``range`` comes first in ``zip`` so that no sixth batch is
# fetched, and a loader with fewer than five batches simply yields fewer
# names.
filename_wandb = [
    val_batch["im_metas"][0]["ori_filename"][0]
    for _, val_batch in zip(range(5), val_loader)
]

# Main loop: train, checkpoint, periodically evaluate, and append one JSON
# line of statistics per epoch to log.txt.
for epoch in range(start_epoch, num_epochs):
    # train for one epoch
    # NOTE(review): this cell imports ``train_one_epoch`` from
    # segm_video.engine, which rebinds the name, so the nine-parameter
    # function defined earlier in the notebook is not the one called with
    # these seven arguments — confirm which one is intended.
    train_logger = train_one_epoch(
        model,
        train_loader,
        optimizer,
        lr_scheduler,
        epoch,
        amp_autocast,
        loss_scaler,
    )

    # save checkpoint
    # Only rank 0 writes; the file is overwritten every epoch.
    if ptu.dist_rank == 0:
        snapshot = dict(
            model=model_without_ddp.state_dict(),
            optimizer=optimizer.state_dict(),
            n_cls=model_without_ddp.n_cls,
            lr_scheduler=lr_scheduler.state_dict(),
        )
        if loss_scaler is not None:
            snapshot["loss_scaler"] = loss_scaler.state_dict()
        snapshot["epoch"] = epoch
        torch.save(snapshot, checkpoint_path)

    # evaluate
    # Evaluate every ``eval_freq`` epochs and always on the last epoch.
    eval_epoch = epoch % eval_freq == 0 or epoch == num_epochs - 1
    if eval_epoch:
        eval_logger = evaluate(
            model,
            val_loader,
            val_seg_gt,
            window_size,
            window_stride,
            amp_autocast,
            filename_wandb,
            epoch,
        )
        print(f"Stats [{epoch}]:", eval_logger, flush=True)
        print("")
        

    # log stats
    if ptu.dist_rank == 0:
        train_stats = {
            k: meter.global_avg for k, meter in train_logger.meters.items()
        }
        # Stays empty on epochs without evaluation.
        val_stats = {}
        if eval_epoch:
            val_stats = {
                k: meter.global_avg for k, meter in eval_logger.meters.items()
            }
            print(val_stats['mean_iou'])
            print(type(val_stats))
            wandb.log({"Mean_iou": val_stats['mean_iou']})
        log_stats = {
            **{f"train_{k}": v for k, v in train_stats.items()},
            **{f"val_{k}": v for k, v in val_stats.items()},
            "epoch": epoch,
            "num_updates": (epoch + 1) * len(train_loader),
        }

        with open(log_dir / "log.txt", "a") as f:
            f.write(json.dumps(log_stats) + "\n")

# NOTE(review): exits with status 1 even after a normal run, which callers
# usually read as failure — confirm whether 0 is intended.
sys.exit(1)






ModuleNotFoundError: No module named 'segm_video'