In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from PIL import Image 
import torchvision.transforms as transforms 
import matplotlib.pyplot as plt
import seaborn as sns
import os
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

Prepare data

In [3]:

def get_sub_path(path):
    sub_path = []
    if isinstance(path, list):
        for p in path:
            if os.path.isdir(p):
                for file in os.listdir(p):
                    sub_path.append(os.path.join(p, file))
            else:
                continue
    else:
        for file in os.listdir(path):
            sub_path.append(os.path.join(path, file))
    return sub_path

def divide_list(list, n):
    for i in range(0, len(list), n):
        yield list[i:i + n]
        
def std(input):
    if input.max() == 0:
        return input
    else:
        result = (input-input.min()) / (input.max()-input.min())
        return result

In [4]:
feature_list = ['IR_drop_features_decompressed/power_i', 'IR_drop_features_decompressed/power_s', 
        'IR_drop_features_decompressed/power_sca', 'IR_drop_features_decompressed/power_all']
label_list = ['IR_drop_features_decompressed/IR_drop']

datapath = './CircuitNet-N28/'
# datapath = '../../CircuitNet/CircuitNet-N28/'
name_list = get_sub_path(os.path.join(datapath, feature_list[-1]))
n_list = divide_list(name_list, 1000)

In [5]:
class PowerDataset(Dataset):
    def __init__(self, root_dir, target_size=(224, 224)):
        self.root_dir = root_dir
        self.feature_dirs = ['power_i', 'power_s', 'power_sca', 'Power_all']
        self.label_dir = 'IR_drop'
        self.target_size = target_size
        # Collect all the feature and label file paths
        self.data = []
        i=0
        for case_name in os.listdir(os.path.join(root_dir, self.feature_dirs[0])):
            feature_paths = [os.path.join(root_dir, feature_dir, case_name) for feature_dir in self.feature_dirs]
            label_path = os.path.join(root_dir, self.label_dir, case_name)
            if all(os.path.exists(fp) for fp in feature_paths) and os.path.exists(label_path):
                self.data.append((feature_paths, label_path))
            i+=1
            if i>100:
                break

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        feature_paths, label_path = self.data[idx]
        features = []
        
        for fp in feature_paths:
            feature = np.load(fp)
            feature = torch.tensor(feature, dtype=torch.float32)
            feature = F.interpolate(feature.unsqueeze(0).unsqueeze(0), size=self.target_size, mode='nearest').squeeze(0).squeeze(0)
            feature = std(feature)
            features.append(feature)
            
        features = torch.stack(features, dim=0)
        
        # Load and process label file
        label = np.load(label_path)
        label = torch.tensor(label, dtype=torch.float32)
        label = F.interpolate(label.unsqueeze(0).unsqueeze(0), size=self.target_size, mode='nearest').squeeze(0).squeeze(0)
        label = label.clamp(1e-6, 50)
        label = (torch.log10(label)+6)/ (np.log10(50)+6)
        
        return features, label
    
root_dir = './CircuitNet-N28/IR_drop_features_decompressed/'
dataset = PowerDataset(root_dir)

In [6]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)


for features, labels in dataloader:
    print(features.shape, labels.shape)
    break

torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size([4, 4, 224, 224]) torch.Size([4, 224, 224])
torch.Size

In [7]:
from swintransformer import *

model_name = 'swin_base_patch4_window7_224'
model = init_model(model_name, input_channels=4, num_classes=0, pretrained=True)

In [8]:

# model.eval()
# with torch.no_grad():
#     ir_prediction = model(features)

In [9]:
# ir_prediction.shape

torch.Size([1, 1, 224, 224])

In [16]:
import utils.losses as losses
import torch.optim as optim
from tqdm import tqdm
from math import cos, pi
model.train()

def checkpoint(model, epoch, save_path):
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    model_out_path = f"./{save_path}/swinTransformer_iters_{epoch}.pth"
    torch.save({'state_dict': model.state_dict()}, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

class CosineRestartLr(object):
    def __init__(self,
                 base_lr,
                 periods,
                 restart_weights = [1],
                 min_lr = None,
                 min_lr_ratio = None):
        self.periods = periods
        self.min_lr = min_lr
        self.min_lr_ratio = min_lr_ratio
        self.restart_weights = restart_weights
        super().__init__()

        self.cumulative_periods = [
            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
        ]

        self.base_lr = base_lr

    def annealing_cos(self, start: float,
                    end: float,
                    factor: float,
                    weight: float = 1.) -> float:
        cos_out = cos(pi * factor) + 1
        return end + 0.5 * weight * (start - end) * cos_out

    def get_position_from_periods(self, iteration: int, cumulative_periods):
        for i, period in enumerate(cumulative_periods):
            if iteration < period:
                return i
        raise ValueError(f'Current iteration {iteration} exceeds '
                        f'cumulative_periods {cumulative_periods}')


    def get_lr(self, iter_num, base_lr: float):
        target_lr = self.min_lr  # type:ignore

        idx = self.get_position_from_periods(iter_num, self.cumulative_periods)
        current_weight = self.restart_weights[idx]
        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
        current_periods = self.periods[idx]

        alpha = min((iter_num - nearest_restart) / current_periods, 1)
        return self.annealing_cos(base_lr, target_lr, alpha, current_weight)

    
    def _set_lr(self, optimizer, lr_groups):
        if isinstance(optimizer, dict):
            for k, optim in optimizer.items():
                for param_group, lr in zip(optim.param_groups, lr_groups[k]):
                    param_group['lr'] = lr
        else:
            for param_group, lr in zip(optimizer.param_groups,
                                        lr_groups):
                param_group['lr'] = lr

    def get_regular_lr(self, iter_num):
        return [self.get_lr(iter_num, _base_lr) for _base_lr in self.base_lr]  # iters

    def set_init_lr(self, optimizer):
        for group in optimizer.param_groups:  # type: ignore
            group.setdefault('initial_lr', group['lr'])
            self.base_lr = [group['initial_lr'] for group in optimizer.param_groups  # type: ignore
        ]


# Build loss
loss = losses.__dict__['L1Loss']()

arg_dict = {'task': 'irdrop_mavi', 'save_path': 'work_dir/irdrop_mavi/', 'pretrained': None, 'max_iters':200, 'plot_roc': False, 'arg_file': None, 'cpu': True, 'dataroot': 'CircuitNet-N28/training_set/IR_drop', 'ann_file_train': './files/train_N28.csv', 'ann_file_test': './files/test_N28.csv', 'dataset_type': 'IRDropDataset', 'batch_size': 2, 'model_type': 'MAVI', 'in_channels': 1, 'out_channels': 4, 'lr': 0.0002, 'weight_decay': 0.01, 'loss_type': 'L1Loss', 'eval_metric': ['NRMS', 'SSIM'], 'threshold': 0.9885, 'ann_file': './files/train_N28.csv', 'test_mode': False}

# Build Optimzer
optimizer = optim.AdamW(model.parameters(), lr=arg_dict['lr'],  betas=(0.9, 0.999), weight_decay=arg_dict['weight_decay'])

# Build lr scheduler
cosine_lr = CosineRestartLr(arg_dict['lr'], [arg_dict['max_iters']], [1], 1e-7)
cosine_lr.set_init_lr(optimizer)

epoch_loss = 0
iter_num = 0
print_freq = 100
# save_freq = 10000
save_freq = 1000

while iter_num < arg_dict['max_iters']:
    with tqdm(total=print_freq) as bar:
        # for feature, label, _ in dataset:     
        for feature, label in dataloader:   
            if arg_dict['cpu']:
                input, target = feature, label
            else:
                input, target = feature.cuda(), label.cuda()

            regular_lr = cosine_lr.get_regular_lr(iter_num)
            cosine_lr._set_lr(optimizer, regular_lr)

            prediction = model(input)
            # print(input.shape)

            optimizer.zero_grad()
            prediction = prediction.squeeze(1)
            pixel_loss = loss(prediction, target)

            epoch_loss += pixel_loss.item()
            pixel_loss.backward()
            optimizer.step()

            iter_num += 1
            
            bar.update(1)

            if iter_num % print_freq == 0:
                break

    print("===> Iters[{}]({}/{}): Loss: {:.4f}".format(iter_num, iter_num, arg_dict['max_iters'], epoch_loss / print_freq))
    if iter_num % save_freq == 0:
        checkpoint(model, iter_num, arg_dict['save_path'])
    epoch_loss = 0

 26%|██▌       | 26/100 [00:38<01:48,  1.47s/it]


===> Iters[26](26/200): Loss: 5.3629


 26%|██▌       | 26/100 [00:36<01:43,  1.40s/it]


===> Iters[52](52/200): Loss: 5.3117


 26%|██▌       | 26/100 [00:36<01:43,  1.40s/it]


===> Iters[78](78/200): Loss: 5.2747


 22%|██▏       | 22/100 [00:31<01:51,  1.43s/it]


===> Iters[100](100/200): Loss: 4.4725


 26%|██▌       | 26/100 [00:36<01:44,  1.41s/it]


===> Iters[126](126/200): Loss: 5.2859


 26%|██▌       | 26/100 [00:36<01:44,  1.41s/it]


===> Iters[152](152/200): Loss: 5.2657


 26%|██▌       | 26/100 [00:36<01:44,  1.41s/it]


===> Iters[178](178/200): Loss: 5.3117


 22%|██▏       | 22/100 [00:32<01:53,  1.46s/it]

===> Iters[200](200/200): Loss: 4.4787





In [5]:
import timm

avail_pretrained_models = timm.list_models()

if 'swin_base_patch4_window7_224' in avail_pretrained_models:
    print("'swin_base_patch4_window7_224' is available in pretrained models.")
else:
    print("'swin_base_patch4_window7_224' is not available in pretrained models.")

'swin_base_patch4_window7_224' is available in pretrained models.


In [4]:
swin_models = [model for model in avail_pretrained_models if 'swin' in model]
print("Available Swin models:", swin_models)

Available Swin models: ['swin_base_patch4_window7_224.ms_in1k', 'swin_base_patch4_window7_224.ms_in22k', 'swin_base_patch4_window7_224.ms_in22k_ft_in1k', 'swin_base_patch4_window12_384.ms_in1k', 'swin_base_patch4_window12_384.ms_in22k', 'swin_base_patch4_window12_384.ms_in22k_ft_in1k', 'swin_large_patch4_window7_224.ms_in22k', 'swin_large_patch4_window7_224.ms_in22k_ft_in1k', 'swin_large_patch4_window12_384.ms_in22k', 'swin_large_patch4_window12_384.ms_in22k_ft_in1k', 'swin_s3_base_224.ms_in1k', 'swin_s3_small_224.ms_in1k', 'swin_s3_tiny_224.ms_in1k', 'swin_small_patch4_window7_224.ms_in1k', 'swin_small_patch4_window7_224.ms_in22k', 'swin_small_patch4_window7_224.ms_in22k_ft_in1k', 'swin_tiny_patch4_window7_224.ms_in1k', 'swin_tiny_patch4_window7_224.ms_in22k', 'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k', 'swinv2_base_window8_256.ms_in1k', 'swinv2_base_window12_192.ms_in22k', 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k', 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k',