In [None]:
"""
    Generate cat2imgkey from /kaggle/input/[selected_dataset]
    cat2imgkey : {"category":[imgkey1, imgkey2 ...]}
"""

import os

base_dir = "/kaggle/input/pandora-clean-straight-dataset-correct/360 CAC/query_images"
cats = os.listdir(base_dir)  
cat2imgkey = {}

for cat in cats:
    cat_path = os.path.join(base_dir, cat)
    if os.path.isdir(cat_path):
        imgkeys = [d for d in os.listdir(cat_path) if os.path.isdir(os.path.join(cat_path, d))]
        cat2imgkey[cat] = imgkeys


train_cat = ['computer', 'car', 'chair', 'window', 'book', 'cabinet']
val_cat = ['door', 'penguin', 'bottle']
test_cat = ['dog', 'hyaenidae', 'boat', 'person']



train_data_num = 0
val_data_num = 0
test_data_num = 0


for cat in train_cat:
  train_data_num += len(cat2imgkey[cat])

for cat in val_cat:
  val_data_num += len(cat2imgkey[cat])

for cat in test_cat:
  test_data_num += len(cat2imgkey[cat])

print(f"train data num : {train_data_num}")
print(f"val data num : {val_data_num}")
print(f"test data num : {test_data_num}")


In [None]:
"""
    Download CFOCNet background repository for the environment
"""

!git clone https://github.com/Allenchou0708/My-Class-agnostic-Few-shot-Object-Counting.git
!mv My-Class-agnostic-Few-shot-Object-Counting/* .

In [None]:
"""
    Download SSIM package
"""

!pip install pytorch_msssim

In [None]:
%%writefile ./model/resblocks.py

"""
    Revise the ResNet backbone of CFOCNet so the query branch can process the deform feature
"""


import torch
from torchvision.models.resnet import resnet50
import torch.nn as nn

def make_resblocks(deform_feature, data_pipe = "reference"):
    """Split a randomly initialised torchvision ResNet-50 into four stages.

    Args:
        deform_feature: when truthy and ``data_pipe == "query"``, the stem
            convolution is replaced by a fresh 4-input-channel conv so the query
            branch can take RGB plus one extra (deform) channel.
        data_pipe: "query" or "reference"; only "query" can get the 4-channel stem.

    Returns:
        (layer0, layer1, layer2, layer3) where
            layer0 -- nn.Sequential(conv1, bn1, relu, maxpool), the stem
            layer1 -- nn.Sequential wrapping ResNet ``layer1``
            layer2 -- nn.Sequential wrapping ResNet ``layer2``
            layer3 -- ResNet ``layer3`` itself (not wrapped)
        Everything after layer3 (layer4, avgpool, fc) is discarded.
    """
    backbone = resnet50(pretrained=False)

    four_channel_stem = data_pipe == "query" and deform_feature
    if four_channel_stem:
        # Same geometry as the stock stem conv, but accepting 4 input channels.
        backbone.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

    stem_names = {'conv1', 'bn1', 'relu', 'maxpool'}
    layer0, layer1, layer2 = nn.Sequential(), nn.Sequential(), nn.Sequential()
    layer3 = nn.Sequential()
    wrapped_stages = {'layer1': layer1, 'layer2': layer2}

    for child_name, child in backbone.named_children():
        if child_name in stem_names:
            layer0.add_module(child_name, child)
        elif child_name in wrapped_stages:
            wrapped_stages[child_name].add_module(child_name, child)
        elif child_name == 'layer3':
            # Returned bare: wrapping it would add a "layer3." prefix to its
            # state_dict keys.
            layer3 = child
        else:
            break

    return layer0, layer1, layer2, layer3


In [None]:
%%writefile ./model/CFOCNet.py

"""
    Replace some modules of the original CFOCNet
"""


import torch
import torch.nn as nn
from .layers import Self_Attn
from .utils import  JDimPool
from .resblocks import make_resblocks
import torch.nn.functional as F


class DropBlock2D(nn.Module):
    """DropBlock regularisation (Ghiasi et al., 2018) for 4-D feature maps.

    Zeroes contiguous ``block_size x block_size`` regions, shared by all
    channels of a sample, then rescales the surviving activations so the mean
    magnitude of the feature map is preserved.

    Args:
        block_size: side length of each dropped square region (odd or even).
        drop_prob: target fraction of activations to drop; 0 disables the layer.
    """

    def __init__(self, block_size, drop_prob):
        super(DropBlock2D, self).__init__()
        self.block_size = block_size
        self.drop_prob = drop_prob

    def forward(self, x, training_mode):
        """Apply DropBlock to ``x`` of shape (B, C, H, W) when ``training_mode == "train"``.

        Any other mode, or ``drop_prob == 0``, returns ``x`` unchanged.
        """
        if training_mode != "train" or self.drop_prob == 0.0:
            return x

        batch, _, height, width = x.shape
        block = self.block_size

        # Probability that a pixel becomes a block centre. Dividing by block**2
        # compensates for each centre zeroing block*block pixels, and the area
        # ratio accounts for centres only fully fitting in the valid region.
        # (The previous formula multiplied by x.numel() / x[0].numel(), i.e. the
        # batch size, so the drop rate grew with B.)
        valid_h = max(height - block + 1, 1)
        valid_w = max(width - block + 1, 1)
        gamma = (self.drop_prob / (block ** 2)) * (height * width) / (valid_h * valid_w)

        # Sample block centres per spatial position, shared across channels: (B, H, W).
        mask_dtype = x.dtype if x.is_floating_point() else torch.float32
        centres = (torch.rand(batch, height, width, device=x.device) < gamma).to(mask_dtype)

        # Grow every centre into a block x block square. Even block sizes make
        # max_pool2d return one extra row/column, so crop back to (H, W).
        block_mask = F.max_pool2d(centres.unsqueeze(1), kernel_size=block, stride=1, padding=block // 2)
        keep = 1 - block_mask[:, :, :height, :width]  # (B, 1, H, W)

        # Rescale survivors; the clamp avoids 0/0 when everything is dropped
        # (the output is all zeros in that case anyway).
        scale = keep.numel() / keep.sum().clamp(min=1.0)
        return x * keep * scale



class CFOCNet(nn.Module):
    """Class-agnostic few-shot counting network (modified CFOCNet).

    A query backbone and a reference backbone (both from make_resblocks) extract
    three feature levels. Pooled reference features are used as correlation
    kernels over the query features. The three resulting matching maps are
    fused with 1x1 convs and upsampled (x2 transpose conv, then x4 bilinear)
    into a single-channel density map.
    """

    def __init__(self, config):
        super(CFOCNet, self).__init__()

        self.config = config

        # Query backbone gets a 4-channel stem when config.data.deform_feature is on;
        # the reference backbone always keeps the 3-channel stem.
        self.res_q0, self.res_q1,self.res_q2,self.res_q3 = make_resblocks(config.data.deform_feature, data_pipe = "query")
        self.res_r0, self.res_r1,self.res_r2,self.res_r3 = make_resblocks(config.data.deform_feature)

        # Self_Attn comes from model/layers (not shown); it returns a tuple whose
        # first element is used as the attended feature.
        self.sa_q1 = Self_Attn(256, nn.ReLU()) # channel attention
        self.sa_q2 = Self_Attn(512, nn.ReLU()) # channel attention
        self.sa_q3 = Self_Attn(1024, nn.ReLU()) # channel attention

        # Presumably max-pools over the reference dimension (dim 1) of size 7.
        # NOTE(review): 7 is hard-coded; it appears to mirror config.data.num_references — confirm.
        self.j_maxpool = JDimPool(7,1) # max pooling k reference image
        
        
        # Shrink level-1 / level-2 reference features so they can serve as small matching kernels.
        self.maxpool_r1 = nn.MaxPool2d(4, stride=4, padding=0)
        self.maxpool_r2 = nn.MaxPool2d(2, stride=2, padding=0)

        # 1x1 convs projecting every level (256/512/1024 channels) to a common 256 channels.
        self.match_query_conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1)
        self.match_query_conv2 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1)
        self.match_query_conv3 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)

        self.match_reference_conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1)
        self.match_reference_conv2 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1)
        self.match_reference_conv3 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)


        # Doubles the spatial size of the fused map: (n - 1) * 2 - 2 + 3 + 1 = 2n.
        self.transpose_convolution = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding = 1)

        # Applied to the query stem output only, and only when config.train.dropblock is set.
        self.dropblock0 = DropBlock2D(block_size=7, drop_prob=self.config.train.dropblock_prop)

        # Fuses the three single-channel matching maps into one map with 1x1 convs.
        self.fusion = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=1, padding=0),  
            nn.ReLU(),
            nn.Conv2d(8, 4, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(4, 1, kernel_size=1, padding=0)  
        )
        
        
    # transform the feature vector into unit vector
    # (L2-normalise along `dim`, the channel axis by default; eps avoids division by zero)
    def l2_normalize(self, x, dim=1, eps=1e-6):
        return x / (x.norm(p=2, dim=dim, keepdim=True) + eps)
    

    def forward(self, queury, references, training_mode="val"):
        """Predict a single-channel density map.

        Args:
            queury: query image batch (assumed (B, C, H, W); C is 4 with deform_feature).
            references: reference crops, (B, N, C, h, w) as implied by the view() calls below.
            training_mode: only "train" enables DropBlock; any other value disables it.

        Returns:
            FS: fused density map, 8x the spatial size of the level-1 matching map.
        """

        # query pipeline

        q0 = self.res_q0(queury)

        if self.config.train.dropblock : 
            q0 = self.dropblock0(q0, training_mode)
        q1 = self.res_q1(q0)
        q2 = self.res_q2(q1)
        q3 = self.res_q3(q2)

        q1_mix, _ = self.sa_q1(q1)
        q2_mix, _ = self.sa_q2(q2)
        q3_mix, _ = self.sa_q3(q3)

        q1_norm_c = self.match_query_conv1(q1_mix)
        q2_norm_c = self.match_query_conv2(q2_mix)
        q3_norm_c = self.match_query_conv3(q3_mix)


        # reference pipeline

        # Fold the N references into the batch so the reference backbone sees 4-D input.
        org_ref_size = references.size()
        references = references.view(-1,org_ref_size[-3],org_ref_size[-2],org_ref_size[-1])

        r0 = self.res_r0(references)
        r1 = self.res_r1(r0)
        r2 = self.res_r2(r1)
        r3 = self.res_r3(r2)

        r1_size = r1.size()
        r2_size = r2.size()
        r3_size = r3.size()

        # Restore the (B, N, C', h', w') layout per level.
        r1 = r1.view(org_ref_size[0],org_ref_size[1],r1_size[-3],r1_size[-2],r1_size[-1])
        r2 = r2.view(org_ref_size[0],org_ref_size[1],r2_size[-3],r2_size[-2],r2_size[-1])
        r3 = r3.view(org_ref_size[0],org_ref_size[1],r3_size[-3],r3_size[-2],r3_size[-1])

        # Collapse the N references into one feature map per level.
        r1_mix = self.j_maxpool(r1)
        r2_mix = self.j_maxpool(r2)
        r3_mix = self.j_maxpool(r3)


        kernel1_norm_size = self.maxpool_r1(r1_mix)
        kernel2_norm_size = self.maxpool_r2(r2_mix)
        kernel3_norm_size = r3_mix

        kernel1_norm_c = self.match_reference_conv1(kernel1_norm_size)
        kernel2_norm_c = self.match_reference_conv2(kernel2_norm_size)
        kernel3_norm_c = self.match_reference_conv3(kernel3_norm_size)



        # image matching
        # For each sample, its own reference feature (1, 256, kh, kw) is used as a
        # conv2d weight over the query feature, giving a (1, 1, h', w') matching map.
        
        M1 = []
        M2 = []
        M3 = []

        for i in range(kernel1_norm_c.size(0)):

            if self.config.train.feature_l2_norm:
                q1_feature, q2_feature, q3_feature = self.l2_normalize(q1_norm_c[i:i+1]), self.l2_normalize(q2_norm_c[i:i+1]), self.l2_normalize(q3_norm_c[i:i+1])
                k1, k2, k3 = self.l2_normalize(kernel1_norm_c[i:i+1]), self.l2_normalize(kernel2_norm_c[i:i+1]), self.l2_normalize(kernel3_norm_c[i:i+1])
            else : 
                q1_feature, q2_feature, q3_feature = q1_norm_c[i:i+1], q2_norm_c[i:i+1], q3_norm_c[i:i+1]
                k1, k2, k3 = kernel1_norm_c[i:i+1], kernel2_norm_c[i:i+1], kernel3_norm_c[i:i+1]

            
            sample_m1 = nn.functional.conv2d(q1_feature, k1, padding=1, stride=2)
            sample_m2 = nn.functional.conv2d(q2_feature, k2, padding=1, stride=2)
            sample_m3 = nn.functional.conv2d(q3_feature, k3, padding=1, stride=2)
            
            M1.append(sample_m1)
            M2.append(sample_m2)
            M3.append(sample_m3)

        M1 = torch.cat(M1,0)
        M2 = torch.cat(M2,0)
        M3 = torch.cat(M3,0)

        # Nearest-neighbour upsampling so M2 / M3 line up with M1 for concatenation
        # (assumes the x2 / x4 factors make the spatial sizes match — TODO confirm for non-256 inputs).
        M2 = nn.functional.interpolate(M2,scale_factor=2)
        M3 = nn.functional.interpolate(M3,scale_factor=4)


        # fusion different granularity
        FS_gran = torch.cat([M1, M2, M3], dim=1)
        FS = self.fusion(FS_gran)

        FS = self.transpose_convolution(FS)

        FS = nn.functional.interpolate(FS, scale_factor=4, mode='bilinear', align_corners=True)

        return FS



In [None]:
%%writefile ./configs/config.yaml



train:
  epochs: 200
  batch_size: 8
  num_workers: 2
  result_path: /kaggle/input/pandora-clean-straight-dataset-correct/360 CAC

  ssim_loss: 1.0e-3
  focal_weight : 15

  feature_l2_norm : True # True/ False
  dropblock : True # True/ False
  dropblock_prop : 0.15

data:
  data_path: .
  num_references : 7
  augmentation_type : "soft" # "soft"/"hard"
  deform_feature : True # True/ False

optimizer:
  lr: 5.0e-5

eval:
  checkpoint: /home/Hacker_Davinci/Desktop/Open_Source/CFOCNet/ckpt/model.ckpt
  sample: True
  image_folder: /home/Hacker_Davinci/Desktop/Open_Source/CFOCNet/image_folder
  inference_time_test : False # True/ False
  show_predict_image_amount : 2



In [None]:
# Loss.py

import torch.nn as nn
from pytorch_msssim import ssim
import torch
import torch.nn.functional as F

class ObjectCountLoss : 
    """Density-map loss: a focal-weighted Huber-style term plus a weighted SSIM term.

    Reads ``config.train.focal_weight`` (extra weight on pixels whose target
    exceeds the focal threshold) and ``config.train.ssim_loss`` (coefficient of
    the SSIM term).
    """

    def __init__(self, config):
        self.config = config
    

    def focal_huber_loss(self, pred, target, delta = 1, focal_threshold = 0.1):
        """Summed, focal-weighted Huber-style loss between ``pred`` and ``target``.

        Per element, with e = pred - target:
            |e| <= delta : (e**2 + delta**2) / (2 * delta)
            |e| >  delta : |e|
        Both pieces equal ``delta`` at |e| == delta, so the loss is continuous
        (smooth-L1 shifted up by delta / 2). Elements whose target exceeds
        ``focal_threshold`` are multiplied by ``config.train.focal_weight``.

        Returns a 0-dim tensor (0 for empty inputs). NaN errors propagate into
        the result instead of being silently skipped.
        """
        focal_weight = torch.full_like(target, 1.0).masked_fill(
            target > focal_threshold, self.config.train.focal_weight
        )
        abs_error = (pred - target).abs()
        per_element = torch.where(
            abs_error <= delta,
            (abs_error ** 2 + delta ** 2) / (2 * delta),
            abs_error,
        )
        return (per_element * focal_weight).sum()

    
    
    def compute_loss(self, predicted_density_map, ground_truth_density_map):
        """Return ``(total_loss, ssim_term)`` for a batch of density maps.

        ssim_term = sum over samples of (1 - SSIM_i), computed in float64 with
        ``data_range=1.0``. It is therefore summed over the batch like the focal
        Huber term and uses the actual batch size.

        The previous form, ``config.batch_size - mean(SSIM)``, mixed a summed
        offset with a batch-averaged similarity and was wrong for a short final
        batch. Its gradient was 1/B of this one, so ``ssim_loss`` may need
        re-tuning.

        total_loss = focal Huber term + config.train.ssim_loss * ssim_term
        """
        huber_term = self.focal_huber_loss(predicted_density_map, ground_truth_density_map)

        # size_average=False -> one SSIM value per sample, shape (N,).
        per_sample_ssim = ssim(
            predicted_density_map.double(),
            ground_truth_density_map.double(),
            data_range=1.0,
            size_average=False,
        )
        SSIM_loss = (1.0 - per_sample_ssim).sum()

        Final_loss = huber_term + self.config.train.ssim_loss * SSIM_loss

        return Final_loss, SSIM_loss






In [None]:
# Dataset

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
import glob
import os
import cv2
import scipy.ndimage
import json
import random
from tqdm import tqdm
import pickle



# Disable OpenCV's internal thread pool: it can deadlock or oversubscribe the CPU
# when cv2 is used inside DataLoader worker processes (config train.num_workers > 0).
cv2.setNumThreads(0) # disable multithread to avoid deadlocks

def add_latitude_channel(img_tensor):
    """Append a latitude channel to a batch of images.

    img_tensor: (B, C, H, W) tensor (typically RGB, C = 3; any H and W work).
    return: (B, C + 1, H, W) tensor. The last channel holds the row latitude,
        spaced linearly from +1 on the top row to -1 on the bottom row and
        constant along each row.

    The channel is created on img_tensor's device, with its dtype when that is
    floating point, so torch.cat also works for CUDA / float64 inputs.
    """
    batch, _, height, width = img_tensor.shape
    lat_dtype = img_tensor.dtype if img_tensor.is_floating_point() else torch.get_default_dtype()
    lat_values = torch.linspace(1, -1, steps=height, device=img_tensor.device, dtype=lat_dtype)

    # (H,) -> (1, 1, H, 1) -> broadcast view (B, 1, H, W); expand avoids a copy.
    lat_channel = lat_values.view(1, 1, height, 1).expand(batch, 1, height, width)

    # concat as the extra last channel
    return torch.cat([img_tensor, lat_channel], dim=1)



class CountingkDataset(Dataset):
    """Few-shot counting dataset over the 360 CAC Kaggle export.

    Each item is one query image plus ``config.data.num_references`` reference
    crops of the same category and the query's density map.

    Per-split metadata (image -> category, sampled reference ids, ground-truth
    counts) is built once and pickled under ``CACHE_ROOT/<dataset_type>``, so
    later runs reuse the same reference sampling.
    """

    # Root of the Kaggle dataset; every image / density-map path is built from it.
    DATA_ROOT = "/kaggle/input/pandora-clean-straight-dataset-correct/360 CAC"
    # Writable location for the metadata cache.
    CACHE_ROOT = "/kaggle/working/cache_data"
    # Category split (trailing numbers are the per-split totals noted by the author,
    # presumably query-image counts).
    SPLIT_CATEGORIES = {
        "train": ['computer', 'car', 'chair', 'window', 'book', 'cabinet'],  # 354
        "val": ['door', 'penguin', 'bottle'],  # 53
        "test": ['dog', 'hyaenidae', 'boat', 'person'],  # 103
    }
    # Attribute name -> pickle file name for the metadata cache (written in this order).
    _CACHE_FILES = {
        "img2cat": "img2cat.pkl",
        "select_imgIds": "select_imgIds.pkl",
        "imgkey2reffnum": "imgkey2reffnum.pkl",
        "imgkey_imgnum2truenum": "imgkey_imgnum2truenum.pkl",
    }

    def __init__(self, config, dataset_type, cat2imgkey):
        """
        Args:
            config: parsed config; reads ``config.data.num_references``.
            dataset_type: "train", "val" or "test".
            cat2imgkey: {category: [query image key, ...]}.

        Raises:
            ValueError: if ``dataset_type`` is not a known split.
        """
        super(CountingkDataset, self).__init__()
        self.config = config
        self.dataset_type = dataset_type
        self.cat2imgkey = cat2imgkey

        if dataset_type not in self.SPLIT_CATEGORIES:
            raise ValueError(
                f"dataset_type must be one of {sorted(self.SPLIT_CATEGORIES)}, got {dataset_type!r}"
            )
        self.select_catIds = list(self.SPLIT_CATEGORIES[dataset_type])

        self.query_transforms = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.reference_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.img2cat = {}
        self.select_imgIds = []
        self.imgkey2reffnum = {}
        self.imgkey_imgnum2truenum = {}

        k = self.config.data.num_references
        cache_dir = os.path.join(self.CACHE_ROOT, dataset_type)

        if not self._load_metadata_cache(cache_dir, k):
            self._build_metadata(k)
            self._save_metadata_cache(cache_dir)

        self.length = len(self.select_imgIds)

    # Load the pickled split metadata; return True only if a complete, matching cache was used.
    def _load_metadata_cache(self, cache_dir, k):
        paths = {attr: os.path.join(cache_dir, fname) for attr, fname in self._CACHE_FILES.items()}
        if not all(os.path.exists(p) for p in paths.values()):
            return False

        loaded = {}
        for attr, path in paths.items():
            # NOTE: pickle can execute code; this cache is only ever written by this class.
            with open(path, "rb") as f:
                loaded[attr] = pickle.load(f)

        # A cache written with a different num_references would silently give the wrong k.
        if any(len(refs) != k for refs in loaded["imgkey2reffnum"].values()):
            print("Cached dataset metadata has a different num_references; rebuilding.")
            return False

        for attr, value in loaded.items():
            setattr(self, attr, value)
        print("Dataset metadata loaded from cache.")
        return True

    # Sample k references (with replacement) per query image and record each ground-truth count.
    def _build_metadata(self, k):
        for c in self.select_catIds:
            imgs = self.cat2imgkey[c]
            # Sorted so seeded runs pick the same references regardless of filesystem order.
            image_paths = sorted(glob.glob(os.path.join(self.DATA_ROOT, "refference_images", c, "*.png")))
            refference_imgkey = [os.path.splitext(os.path.basename(p))[0] for p in image_paths]
            if imgs and not refference_imgkey:
                raise FileNotFoundError(f"No reference images found for category {c!r}")

            for i in imgs:
                self.img2cat[i] = c
                self.select_imgIds.append(i)
                # With replacement, so categories with fewer than k crops still work.
                self.imgkey2reffnum[i] = np.random.choice(refference_imgkey, k, True)

        for imgkey in tqdm(self.select_imgIds):
            dm_np = np.loadtxt(self._density_map_path(self.img2cat[imgkey], imgkey))
            self.imgkey_imgnum2truenum[(imgkey, "0")] = float(np.sum(dm_np))

    # Pickle the split metadata so later runs reuse the same reference sampling.
    def _save_metadata_cache(self, cache_dir):
        os.makedirs(cache_dir, exist_ok=True)
        for attr, fname in self._CACHE_FILES.items():
            with open(os.path.join(cache_dir, fname), "wb") as f:
                pickle.dump(getattr(self, attr), f)

    # Path of the density-map text file for one query image.
    def _density_map_path(self, selected_cat, imgkey):
        return os.path.join(self.DATA_ROOT, "density_maps", selected_cat, imgkey, "0.txt")

    # cv2.imread returns None instead of raising on a missing/corrupt file; make that explicit.
    @staticmethod
    def _read_color_image(path):
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return img

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return (query, references, density_map, true_count, imgkey, category) for one query image.

        query comes from generate_query_tensor; references is a stack of the
        sampled crops; density_map is ToTensor applied to the loaded text map;
        true_count is the cached sum of the density map (0.0 if missing).
        """
        imgkey = self.select_imgIds[idx]
        selected_cat = self.img2cat[imgkey]

        query_images_path = os.path.join(self.DATA_ROOT, "query_images", selected_cat, imgkey, "0.png")
        query_tensor = self.generate_query_tensor(self._read_color_image(query_images_path))

        references_tensor = self.generate_references_tensor(selected_cat, self.imgkey2reffnum[imgkey])

        density_map = np.loadtxt(self._density_map_path(selected_cat, imgkey))
        density_maps_tensor = self.query_transforms(density_map)

        true_count = self.imgkey_imgnum2truenum.get((imgkey, "0"), 0.0)

        return query_tensor, references_tensor, density_maps_tensor, true_count, imgkey, selected_cat

    def generate_query_tensor(self, img):
        """Zero-pad a BGR (H, W, 3) image to a centred square, convert to RGB and tensorise.

        Odd size differences put the extra row/column at the bottom/right so the
        result is always square.
        NOTE(review): the density map in __getitem__ is not padded, so it only
        lines up with square query images — presumably all queries are square; confirm.
        """
        height, width = img.shape[:2]
        side = max(height, width)
        pad_h, pad_w = side - height, side - width
        top, left = pad_h // 2, pad_w // 2

        pad_img = np.pad(img, ((top, pad_h - top), (left, pad_w - left), (0, 0)))
        pad_img = cv2.cvtColor(pad_img, cv2.COLOR_BGR2RGB)
        return self.query_transforms(pad_img)

    def generate_references_tensor(self,select_catId, select_reference_imgIds,k=5):
        """Load the listed reference crops of ``select_catId`` and stack them along dim 0.

        ``k`` is unused (kept for backward compatibility); the number of crops is
        ``len(select_reference_imgIds)``. All crops must share one size for torch.stack.
        """
        refference_path = os.path.join(self.DATA_ROOT, "refference_images", select_catId)

        references_tensor = []
        for ref_id in select_reference_imgIds:
            crop_img = self._read_color_image(os.path.join(refference_path, f'{ref_id}.png'))
            crop_img = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
            references_tensor.append(self.reference_transform(crop_img))

        return torch.stack(references_tensor, 0) # k reference image







In [None]:
# Data Augmentation

class DataAugmentor:
    """Random per-sample augmentation of (query, references, target) batches.

    Per-sample layouts, as indexed by the methods below:
      * query      ``q_sample``: (C, H, W). With ``config.data.deform_feature``
        on, channels ``[:3]`` are RGB and ``[3:]`` hold the deform feature.
        Photometric ops (colour jitter, blurs) leave those extra channels untouched.
      * references ``r_sample``: (N, C, H, W), one image per reference.
      * target     ``t_sample``: (C_t, H, W) density map.
    Geometric ops (flip, rotation) are applied identically to all three so the
    density map stays aligned with the query.

    ``torchvision.transforms`` is imported inside the methods that need it, so
    this cell does not depend on a later cell having defined ``T``.
    """

    def __init__(self, config):
        self.augmentation_type = config.data.augmentation_type
        self.deform_feature = config.data.deform_feature
        
        print(f"apply {self.augmentation_type} data augmentation")

    # Split the query into (rgb, deform); deform is None when deform_feature is off.
    def _split_deform(self, q_sample):
        if self.deform_feature:
            return q_sample[:3, :, :], q_sample[3:, :, :]
        return q_sample, None

    # Re-attach the deform channels removed by _split_deform.
    def _merge_deform(self, q_rgb, deform):
        if deform is None:
            return q_rgb
        return torch.cat([q_rgb, deform], dim=0)

    def apply_jitter(self, img, brightness_factor, contrast_factor, saturation_factor, hue_factor):
        """Apply fixed brightness -> contrast -> saturation -> hue adjustments to ``img``."""
        import torchvision.transforms as T

        img = T.functional.adjust_brightness(img, brightness_factor)
        img = T.functional.adjust_contrast(img, contrast_factor)
        img = T.functional.adjust_saturation(img, saturation_factor)
        img = T.functional.adjust_hue(img, hue_factor)
        return img

    def flip(self, q_sample, r_sample, t_sample, direction = "horizontal"):
        """Flip all three tensors: "vertical" flips the height axis, anything else the width axis."""
        if direction == "vertical" : 
            q_aug = torch.flip(q_sample, dims=[1])  # flip height
            r_aug = torch.flip(r_sample, dims=[2])
            t_aug = torch.flip(t_sample, dims=[1])
        else : 
            q_aug = torch.flip(q_sample, dims=[2])  # flip width
            r_aug = torch.flip(r_sample, dims=[3])
            t_aug = torch.flip(t_sample, dims=[2])
        return q_aug, r_aug, t_aug

    def rotation(self, q_sample, r_sample, t_sample, rotation_time=1):
        """Rotate all three tensors by ``rotation_time`` quarter turns in the (H, W) plane."""
        q_aug = q_sample.rot90(rotation_time, [1,2])
        r_aug = r_sample.rot90(rotation_time, [2,3])
        t_aug = t_sample.rot90(rotation_time, [1,2])
        return q_aug, r_aug, t_aug

    def color_jitter(self, q_sample, r_sample, t_sample):
        """Apply one randomly drawn colour jitter to the query RGB and every reference.

        Tensors round-trip through PIL, so values are expected in [0, 1] and get
        quantised to 8 bits. The target is returned unchanged.
        """
        import torchvision.transforms as T

        to_pil = T.ToPILImage()
        to_tensor = T.ToTensor()

        color_jitter = T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2)

        # get_params also returns a random op order; it is ignored and apply_jitter's
        # fixed order is used instead.
        _, brightness_factor, contrast_factor, saturation_factor, hue_factor = T.ColorJitter.get_params(
            brightness=color_jitter.brightness,
            contrast=color_jitter.contrast,
            saturation=color_jitter.saturation,
            hue=color_jitter.hue
        )
        factors = (brightness_factor, contrast_factor, saturation_factor, hue_factor)

        def jitter(img):
            return to_tensor(self.apply_jitter(to_pil(img), *factors))

        # do the same color jitter for every sample
        r_aug = torch.stack([jitter(r) for r in r_sample], dim=0)

        q_rgb, deform = self._split_deform(q_sample)
        q_aug = self._merge_deform(jitter(q_rgb), deform)

        return q_aug, r_aug, t_sample  # target doesn't need color jitter

    def gaussian_blur(self, q_sample, r_sample, t_sample):
        """Gaussian-blur the query RGB channels (9x9 kernel, sigma drawn from [2, 5])."""
        import torchvision.transforms as T

        blur = T.GaussianBlur(kernel_size=(9,9), sigma=(2, 5))

        q_rgb, deform = self._split_deform(q_sample)
        q_aug = self._merge_deform(blur(q_rgb), deform)

        return q_aug, r_sample, t_sample

    def motion_blur(self, q_sample, r_sample, t_sample):
        """Horizontal motion blur: box-average 9 neighbouring pixels along the width axis.

        Only the query RGB channels are blurred (zero padding at the borders); the
        deform channels, references and target are returned unchanged. The original
        code indexed the CHW query as HWC and so actually blurred along the height
        axis; this blurs along width as its 1 x k kernel intended. The kernel is
        built on the query's device and with its dtype.
        """
        kernel_size = 9

        q_rgb, deform = self._split_deform(q_sample)
        channels = q_rgb.shape[0]
        kernel = torch.full(
            (channels, 1, 1, kernel_size), 1.0 / kernel_size,
            dtype=q_rgb.dtype, device=q_rgb.device,
        )
        # Depthwise conv: each channel is blurred independently.
        q_blur = torch.nn.functional.conv2d(
            q_rgb.unsqueeze(0), kernel, padding=(0, kernel_size // 2), groups=channels
        ).squeeze(0)

        return self._merge_deform(q_blur, deform), r_sample, t_sample

    def apply_data_augmentation(self, queries, references, target):
        """Dispatch to the "soft" augmentation, or to the "hard" one for any other type."""
        if self.augmentation_type == "soft":
            return self.apply_soft_data_augmentation(queries, references, target)
        else:
            return self.apply_hard_data_augmentation(queries, references, target)

    def apply_soft_data_augmentation(self, queries, references, target):
        """Apply exactly one uniformly chosen option per sample.

        Options: original, horizontal flip, vertical flip, rotation by 90 / 180 /
        270 degrees, colour jitter, gaussian blur, motion blur.
        """
        soft_ops = (
            lambda q, r, t: (q, r, t),                                    # 0: original
            lambda q, r, t: self.flip(q, r, t, direction="horizontal"),   # 1
            lambda q, r, t: self.flip(q, r, t, direction="vertical"),     # 2
            lambda q, r, t: self.rotation(q, r, t, rotation_time=1),      # 3: 90
            lambda q, r, t: self.rotation(q, r, t, rotation_time=2),      # 4: 180
            lambda q, r, t: self.rotation(q, r, t, rotation_time=3),      # 5: 270
            self.color_jitter,                                            # 6
            self.gaussian_blur,                                           # 7
            self.motion_blur,                                             # 8
        )

        augmented_queries, augmented_references, augmented_target = [], [], []
        for i in range(references.shape[0]):
            aug_type = random.randint(0, len(soft_ops) - 1)
            q_aug, r_aug, t_aug = soft_ops[aug_type](queries[i], references[i], target[i])

            augmented_queries.append(q_aug)
            augmented_references.append(r_aug)
            augmented_target.append(t_aug)

        queries_aug = torch.stack(augmented_queries, dim=0)
        references_aug = torch.stack(augmented_references, dim=0)
        target_aug = torch.stack(augmented_target, dim=0)

        return queries_aug, references_aug, target_aug

    def apply_hard_data_augmentation(self, queries, references, target, debug=False):
        """Independently apply a random combination of augmentations per sample.

        Per sample, in this order:
          * horizontal flip (p = 1/2)
          * vertical flip (p = 1/2)
          * rotation by 90 / 180 / 270 degrees (each p = 1/5; none p = 2/5)
          * colour jitter (p = 1/2)
          * gaussian blur (p = 1/4) or motion blur (p = 1/4)

        ``debug`` is accepted for compatibility and unused.
        """
        augmented_queries, augmented_references, augmented_target = [], [], []

        for i in range(references.shape[0]):
            q_sample, r_sample, t_sample = queries[i], references[i], target[i]

            # All decisions are drawn before any op runs, so the draw order is fixed.
            whether_h_flip = random.randint(0, 1)
            whether_v_flip = random.randint(0, 1)
            whether_rotation = random.randint(0, 4)   # 1/2/3 -> quarter turns; 0 and 4 -> none
            whether_color_jitter = random.randint(0, 1)
            whether_blur = random.randint(0, 3)       # 1 -> gaussian, 2 -> motion; 0 and 3 -> none

            if whether_h_flip == 1:
                q_sample, r_sample, t_sample = self.flip(q_sample, r_sample, t_sample, direction = "horizontal")

            if whether_v_flip == 1:
                q_sample, r_sample, t_sample = self.flip(q_sample, r_sample, t_sample, direction = "vertical")

            if whether_rotation in (1, 2, 3):
                q_sample, r_sample, t_sample = self.rotation(q_sample, r_sample, t_sample, rotation_time=whether_rotation)

            if whether_color_jitter == 1:
                q_sample, r_sample, t_sample = self.color_jitter(q_sample, r_sample, t_sample)

            if whether_blur == 1:
                q_sample, r_sample, t_sample = self.gaussian_blur(q_sample, r_sample, t_sample)
            elif whether_blur == 2:
                q_sample, r_sample, t_sample = self.motion_blur(q_sample, r_sample, t_sample)

            augmented_queries.append(q_sample)
            augmented_references.append(r_sample)
            augmented_target.append(t_sample)

        queries_aug = torch.stack(augmented_queries, dim=0)
        references_aug = torch.stack(augmented_references, dim=0)
        target_aug = torch.stack(augmented_target, dim=0)

        return queries_aug, references_aug, target_aug

In [None]:
# Trainer

import importlib
import os
import numpy as np
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import optim
import torch
import logging
import matplotlib.pyplot as plt
from torchvision.utils import save_image, make_grid
import model.CFOCNet
importlib.reload(model.CFOCNet)
from model.CFOCNet import CFOCNet
import model.resblocks
importlib.reload(model.resblocks)
from tqdm import tqdm
from PIL import Image
import csv
from datetime import datetime
from torchvision.transforms import transforms
from collections import deque
import copy
import torchvision.transforms as T
import torch.nn.functional as F
import time




class Runner:
    """Drive training and evaluation of CFOCNet inside this notebook.

    Args:
        args: parsed command-line namespace (stored on the instance; not read by
            the methods below).
        config: nested namespace built from the YAML config. The methods below
            read ``config.device``, ``config.data.deform_feature``,
            ``config.train.batch_size``, ``config.train.epochs``,
            ``config.optimizer.lr``, ``config.eval.show_predict_image_amount``
            and, optionally, ``config.eval.inference_time_test``.
        logger: logger that receives one ``"Step : ... Loss: ..."`` line per
            training step.
    """

    def __init__(self, args, config, logger):
        self.args = args
        self.config = config
        self.logger = logger

    def make_model_filename(self, global_step):
        """Return a checkpoint filename of the form ``MM_DD_HH_MM_<global_step>_cfocnet.pt``."""
        now = datetime.now().strftime("%m_%d_%H_%M")
        return f"{now}_{global_step}_cfocnet.pt"

    def save_models(self, model, optimizer, global_step, save_dir="/kaggle/working/models"):
        """Save the model and optimizer state dicts into ``save_dir``.

        Args:
            model: module whose ``state_dict()`` is stored under ``"model_state"``.
            optimizer: optimizer whose ``state_dict()`` is stored under ``"optim_state"``.
            global_step: training step embedded in the filename.
            save_dir: directory to write into (created if missing).

        Returns:
            The checkpoint filename (without the directory).
        """
        os.makedirs(save_dir, exist_ok=True)
        fname = self.make_model_filename(global_step)
        path = os.path.join(save_dir, fname)
        torch.save(
            {
                "model_state": model.state_dict(),
                "optim_state": optimizer.state_dict(),
            },
            path,
        )
        return fname

    def load_models(self, model, model_name, optimizer="", save_dir="/kaggle/input/09_02_03_49_1400_cfocnet/pytorch/default/1"):
        """Load ``<save_dir>/<model_name>_cfocnet.pt`` into ``model`` (and ``optimizer``).

        Args:
            model: module that receives ``checkpoint["model_state"]``.
            model_name: checkpoint prefix, i.e. a filename produced by
                ``save_models`` without its ``_cfocnet.pt`` suffix.
            optimizer: optimizer that receives ``checkpoint["optim_state"]``;
                ``""`` (the default) or ``None`` skips optimizer restoration.
            save_dir: directory containing the checkpoint file.
        """
        checkpoint = torch.load(
            os.path.join(save_dir, f"{model_name}_cfocnet.pt"),
            # Remap stored tensors onto the configured device so a checkpoint
            # saved on GPU can still be loaded on a CPU-only machine.
            map_location=self.config.device,
        )
        model.load_state_dict(checkpoint["model_state"])
        if optimizer is not None and optimizer != "":
            optimizer.load_state_dict(checkpoint["optim_state"])

    def show_predicted_output(self, queries, target, FS, evaluation_mode):
        """Plot sample 0 of a batch: query image, target map and ReLU-clipped prediction ``FS``.

        Tensors are indexed as ``[sample, channel, H, W]`` and transposed to
        ``H x W x C`` for ``imshow``.
        """
        print(f"{evaluation_mode} output : ")

        query = queries[0]
        if self.config.data.deform_feature:
            # Keep only the first 3 (image) channels; the extra channel added
            # for deform features is not displayable as RGB.
            query = query[:3, :, :]

        panels = [
            query.detach().cpu().numpy().transpose(1, 2, 0),
            target[0].detach().cpu().numpy().transpose(1, 2, 0),
            F.relu(FS[0]).detach().cpu().numpy().transpose(1, 2, 0),
        ]

        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        for ax, img in zip(axes, panels):
            ax.imshow(img)
            ax.axis('off')
        plt.tight_layout()
        plt.show()
        # This is called repeatedly during training/evaluation; release the
        # figure so memory does not grow over a long run.
        plt.close(fig)

    def evaluate_by_metrics(self, true_count, FS):
        """Compare ground-truth counts with predicted counts for one batch.

        The predicted count of a sample is the sum of ``relu(FS_i)``. Both true
        and predicted counts are rounded to 2 decimals before comparison.

        Returns:
            ``(mae, mse, nae, sre)``, each a SUM over the batch (callers divide
            by the number of samples). ``nae`` and ``sre`` divide by the true
            count, so a zero ground-truth count yields inf/nan.
        """
        true_counts = np.array([round(c.item(), 2) for c in true_count])
        pred_counts = np.array([round(torch.sum(F.relu(fs)).item(), 2) for fs in FS])

        error = true_counts - pred_counts
        abs_error = np.abs(error)
        sq_error = error ** 2

        this_mae = np.sum(abs_error)
        this_mse = np.sum(sq_error)
        this_nae = np.sum(abs_error / true_counts)
        this_sre = np.sum(sq_error / true_counts)
        return this_mae, this_mse, this_nae, this_sre

    def train(self, cat2imgkey, selected_model_name, selected_step = 1 , val_step=500, store_step = 1000, valid_val_step=5000):
        """Train CFOCNet on the ``'train'`` split with periodic validation and checkpoints.

        Args:
            cat2imgkey: mapping ``{category: [imgkey, ...]}`` passed to ``CountingkDataset``.
            selected_model_name: checkpoint prefix to resume from; ``""`` trains from scratch.
            selected_step: global step to resume at. Batches for earlier steps
                are still drawn from the loader but skipped without training.
            val_step: validate every ``val_step`` steps (once ``valid_val_step`` is reached).
                A checkpoint is saved whenever both validation NAE and SRE improve.
            store_step: unconditionally save a checkpoint every ``store_step`` steps.
            valid_val_step: first global step at which validation may run.
        """
        # Drop progress bars left over from an interrupted earlier run of the
        # cell (uses tqdm's private instance registry).
        tqdm._instances.clear()

        dataset = CountingkDataset(self.config, 'train', cat2imgkey)
        data_loader = DataLoader(dataset, self.config.train.batch_size, shuffle=True, pin_memory=True)
        print("dataset length ", len(dataset))
        # len(data_loader) == ceil(len(dataset) / batch_size) since drop_last is False.
        print(f"1 epoch = {len(data_loader)} batch")
        data_augmentor = DataAugmentor(self.config)

        net = CFOCNet(self.config)
        net.to(self.config.device)
        optimizer = optim.Adam(net.parameters(), lr=self.config.optimizer.lr)

        if selected_model_name != "":
            self.load_models(net, selected_model_name, optimizer=optimizer,
                             save_dir=f"/kaggle/input/{selected_model_name}_cfocnet/pytorch/default/1")

        loss_fn = ObjectCountLoss(self.config)
        best_nae = float("inf")
        best_sre = float("inf")
        global_step = 1

        pbar = tqdm(total=self.config.train.epochs * len(data_loader), desc="Training", unit="step")
        try:
            for epoch in range(1, self.config.train.epochs + 1):
                for sample in data_loader:

                    if global_step < selected_step:
                        # Fast-forward to the resume step without training.
                        pbar.set_description(f"Epoch {epoch}")
                        pbar.update(1)
                        global_step += 1
                        continue

                    optimizer.zero_grad()
                    # test() switches the net to eval mode, so re-enable training mode every step.
                    net.train()

                    queries = sample[0]
                    references = sample[1]
                    target = sample[2]
                    true_count = sample[3].to(self.config.device)  # per-sample ground-truth counts
                    imgkey = sample[4]
                    selected_cat = sample[5]

                    if self.config.data.deform_feature:
                        queries = add_latitude_channel(queries)

                    queries, references, target = data_augmentor.apply_data_augmentation(queries, references, target)

                    queries = queries.to(self.config.device)
                    references = references.to(self.config.device)
                    target = target.to(self.config.device)

                    FS = net(queries, references, "train")
                    final_loss, _ssim_loss = loss_fn.compute_loss(FS, target)
                    self.logger.info(f"Step : {global_step} Loss: {final_loss.item()}")

                    final_loss.backward()
                    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=10.0)
                    optimizer.step()

                    if global_step >= valid_val_step and global_step % val_step == 0:
                        this_mae, this_mse, this_nae, this_sre = self.evaluate_by_metrics(true_count, FS)

                        self.show_predicted_output(queries, target, FS, "train")

                        b_size = len(FS)
                        print(f"imgkey : {imgkey[0]}")
                        print(f"count category : {selected_cat[0]}")
                        print(f'nae = {this_nae/b_size}')
                        print(f'sre = {this_sre/b_size}')
                        print("\n")

                        val_nae, val_sre = self.test("val", cat2imgkey, net=net, whether_show=True)

                        # Only keep a "best" checkpoint when BOTH metrics improve.
                        if (val_nae < best_nae) and (val_sre < best_sre):
                            self.save_models(net, optimizer, global_step)
                            best_nae = val_nae
                            best_sre = val_sre
                            print(f"we have best nae : {val_nae} and sre: {val_sre}")

                    if global_step % store_step == 0:
                        self.save_models(net, optimizer, global_step)

                    pbar.set_description(f"Epoch {epoch}")
                    pbar.update(1)
                    global_step += 1
        finally:
            pbar.close()

    def test(self, evaluation_mode, cat2imgkey, net= "", selected_model_name="", whether_show=False, inference_time_test=False):
        """Evaluate a model on the ``evaluation_mode`` split (e.g. ``"val"`` or ``"test"``).

        Args:
            evaluation_mode: split name passed to ``CountingkDataset``.
            cat2imgkey: mapping ``{category: [imgkey, ...]}``.
            net: an already-built model. Ignored when ``selected_model_name`` is given.
            selected_model_name: checkpoint prefix to build and load a fresh ``CFOCNet`` from.
            whether_show: accepted for API compatibility but currently unused; the
                number of plotted batches is ``config.eval.show_predict_image_amount``.
            inference_time_test: print the average wall time per image. Timing is
                also enabled when ``config.eval.inference_time_test`` is truthy.
                The measured time includes data loading and any plotting.

        Returns:
            ``(mean_nae, mean_sre)``: per-batch metric sums divided by the number of images.

        Raises:
            ValueError: if neither ``net`` nor ``selected_model_name`` is provided,
                or the split contains no images.
        """
        if selected_model_name != "":
            net = CFOCNet(self.config)
            self.load_models(net, selected_model_name,
                             save_dir=f"/kaggle/input/{selected_model_name}_cfocnet/pytorch/default/1")
        if net is None or isinstance(net, str):
            raise ValueError("test() needs either a `net` instance or a `selected_model_name` checkpoint to load")

        net.to(self.config.device)
        net.eval()

        group_dataset = CountingkDataset(self.config, evaluation_mode, cat2imgkey)
        count = len(group_dataset)
        if count == 0:
            raise ValueError(f"the '{evaluation_mode}' split is empty; nothing to evaluate")
        data_loader = DataLoader(group_dataset, self.config.train.batch_size)

        mae_sum, mse_sum, nae_sum, sre_sum = 0, 0, 0, 0
        measure_time = inference_time_test or getattr(self.config.eval, "inference_time_test", False)

        start_time = time.time()
        with torch.no_grad():
            for i, sample in enumerate(data_loader):
                queries = sample[0]
                references = sample[1].to(self.config.device)
                target = sample[2].to(self.config.device)
                true_count = sample[3].to(self.config.device)  # per-sample ground-truth counts
                imgkey = sample[4]
                selected_cat = sample[5]

                if self.config.data.deform_feature:
                    queries = add_latitude_channel(queries)
                queries = queries.to(self.config.device)

                FS = net(queries, references)

                this_mae, this_mse, this_nae, this_sre = self.evaluate_by_metrics(true_count, FS)
                mae_sum += this_mae
                mse_sum += this_mse
                nae_sum += this_nae
                sre_sum += this_sre

                # Plot the first few batches (sample 0 of each) for inspection.
                if i < self.config.eval.show_predict_image_amount:
                    self.show_predicted_output(queries, target, FS, evaluation_mode)

                    print(f"imgkey : {imgkey[0]}")
                    print(f"count category : {selected_cat[0]}")
                    print(f'this nae = {this_nae}')
                    print(f'this sre = {this_sre}')
                    print("\n")
        end_time = time.time()

        if measure_time:
            avg_time_per_image = (end_time - start_time) / count
            print(f"Average time per image : {avg_time_per_image:.6f} sec ({1/avg_time_per_image:.2f} FPS)")

        print(f'nae = {nae_sum/count}')
        print(f'sre = {sre_sum/count}')

        return nae_sum/count, sre_sum/count
        





In [None]:
# Main Function

import torch
import logging
import os
import sys
import shutil
import numpy as np
import argparse
import yaml 
import copy
from datetime import datetime
import sys



def parse_args_and_config():
    """Parse CLI flags from ``sys.argv``, seed RNGs, load the YAML config and build a file logger.

    The config file is read from ``configs/<--config>``.

    Returns:
        args: parsed ``argparse.Namespace`` with an extra ``log_path`` attribute.
        new_config: the YAML config converted with ``dict2namespace``, plus a
            ``device`` attribute (CUDA when available, otherwise CPU).
        logger: the ``"train_logger"`` logger, writing plain messages to a
            timestamped file under ``/kaggle/working/logging_files``.
    """
    import random

    parser = argparse.ArgumentParser(description=globals().get('__doc__'))

    parser.add_argument('--config', type=str, required=True,  help='Path to the config file')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--exp', type=str, default='exp', help='Path for saving running related data.')
    parser.add_argument('--doc', type=str, required=True, help='A string for documentation purpose. '
                                                               'Will be the name of the log folder.')
    parser.add_argument('--comment', type=str, default='', help='A string for experiment comment')
    parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
    parser.add_argument('--train', action='store_true', help='Whether to train the model')
    parser.add_argument('--val', action='store_true', help='Whether to val the model')
    parser.add_argument('--test', action='store_true', help='Whether to test the model')
    parser.add_argument('--sample', action='store_true', help='Whether to produce samples from the model')
    parser.add_argument('--resume_training', action='store_true', help='Whether to resume training')
    parser.add_argument('-i', '--image_folder', type=str, default='images', help="The folder name of samples")

    args = parser.parse_args()
    args.log_path = os.path.join(args.exp, 'logs', args.doc)

    # Seed every RNG source so runs are reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # The config is consumed as a plain mapping (see dict2namespace), so a safe
    # loader suffices; use the libyaml-backed one when PyYAML provides it.
    # NOTE(review): assumes the config uses no Python-specific YAML tags.
    safe_loader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
    with open(os.path.join('configs', args.config), 'r') as f:
        config = yaml.load(f, Loader=safe_loader)
    new_config = dict2namespace(config)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    new_config.device = device

    os.makedirs("/kaggle/working/logging_files", exist_ok=True)

    now = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = f"/kaggle/working/logging_files/log_file_{now}.txt"
    logger = logging.getLogger("train_logger")
    logger.setLevel(logging.INFO)
    # getLogger() returns the same object every time, so re-running this cell
    # would otherwise stack one FileHandler per run (duplicated lines written
    # to old files, leaked file handles). Close handlers from earlier runs.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(file_handler)

    # Record the device in this run's log file (not the unconfigured root logger).
    logger.info("Using device: {}".format(device))

    return args, new_config, logger


def dict2namespace(config):
    """Recursively turn a nested dict into nested ``argparse.Namespace`` objects.

    This allows attribute-style access to configuration values, e.g.
    ``config.train.batch_size`` instead of ``config["train"]["batch_size"]``.
    Only ``dict`` values are converted; any other value (lists included) is
    stored unchanged.
    """
    def _convert(raw):
        # Nested mappings become namespaces; leaves are kept as-is.
        return dict2namespace(raw) if isinstance(raw, dict) else raw

    result = argparse.Namespace()
    for attr_name, raw_value in config.items():
        setattr(result, attr_name, _convert(raw_value))
    return result




def main(cat2imgkey, selected_model_name, selected_step, cli_args=None):
    """Notebook entry point: parse arguments, build a ``Runner`` and dispatch.

    Args:
        cat2imgkey: mapping ``{category: [imgkey, ...]}`` built in the first cell.
        selected_model_name: checkpoint prefix to resume from / evaluate;
            ``""`` trains from scratch.
        selected_step: global step to resume training from (1 = from scratch).
        cli_args: command-line tokens for ``parse_args_and_config``. Defaults to
            training with ``config.yaml``; pass ``--val`` or ``--test`` instead
            of ``--train`` to evaluate.
    """
    if cli_args is None:
        cli_args = [
            '--config', 'config.yaml',
            '--doc', 'traininglog',
            '--train',
            # '--val',
            # '--test',
        ]

    # parse_args_and_config() reads sys.argv; swap our tokens in only for the
    # parse so the kernel's own argv is restored for the rest of the session.
    saved_argv = sys.argv
    sys.argv = [''] + list(cli_args)
    try:
        args, config, logger = parse_args_and_config()
    finally:
        sys.argv = saved_argv

    print(config)

    runner = Runner(args, config, logger)
    if args.test:
        runner.test("test", cat2imgkey, selected_model_name=selected_model_name)
    elif args.val:
        runner.test("val", cat2imgkey, selected_model_name=selected_model_name)
    elif args.train:
        runner.train(cat2imgkey, selected_model_name, selected_step=selected_step,
                     val_step=500, store_step=1000, valid_val_step=500)
    else:
        print("Add --test or --train in your command line or run.sh !")




if __name__ == '__main__':
    import traceback

    try:
        # Checkpoint prefix such as "10_17_06_57_10"; set "" to train from scratch.
        selected_model_name = ""
        # Global step to resume from; set 1 to train from scratch.
        selected_step = 1
        main(cat2imgkey, selected_model_name, selected_step)
    except Exception as e:
        # Keep the notebook session alive, but show where the failure happened
        # rather than only the exception message.
        print(f"Error : {e}")
        traceback.print_exc()
