In [None]:
import os
import torch
import argparse
import itertools
import numpy as np
from tqdm import tqdm
from urllib.request import urlopen
from PIL import Image
import timm
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from glob import glob
from sklearn.model_selection import train_test_split
import pytorch_model_summary as tms
import torch.nn as nn
import random
from torch.nn.modules.batchnorm import _BatchNorm
import torchmetrics
import matplotlib.pyplot as plt
import torch.nn.functional as F
print(f"GPUs used:\t{torch.cuda.device_count()}")

# Preferred CUDA device index. If that GPU does not exist, fall back to cuda:0,
# then to CPU, instead of failing later at the first `.to(device)` call.
GPU_INDEX = 6
if torch.cuda.device_count() > GPU_INDEX:
    device = torch.device("cuda", GPU_INDEX)
elif torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
print(f"Device:\t\t{device}")

In [None]:

# Run configuration shared by the cells below.
params = dict(
    image_size=512,                 # images are resized to image_size x image_size
    lr=2e-4,                        # learning rate handed to the optimizer
    beta1=0.5,                      # presumably Adam-style betas; not read by any visible cell
    beta2=0.999,
    batch_size=1,                   # DataLoader batch size
    epochs=1000,                    # not read by any visible cell (training not shown)
    n_classes=2,
    data_path='../../data/NIA/BRNT/',
    inch=3,                         # presumably input channel count; not read by any visible cell
)

In [None]:
# Module-level preprocessing pipeline: PIL image -> tensor via ToTensor (which
# scales 8-bit images to [0, 1]), then per-channel normalization (x - 0.5) / 0.5.
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

def transback(data: Tensor) -> Tensor:
    """Undo the (x - 0.5) / 0.5 normalization, mapping values in [-1, 1] back to [0, 1]."""
    return 0.5 + data / 2

class CustomDataset(Dataset):
    """Image dataset compatible with torch.utils.data.DataLoader.

    Yields ``(image, path)`` pairs. The file at ``path`` is loaded as RGB and resized
    to ``image_size x image_size`` (taken from the config dict given to the
    constructor). It is then converted to a tensor normalized with mean/std 0.5 per
    channel and, when ``augment`` is set, flipped at random.

    Args:
        parmas: config dict; must provide ``'image_size'``.
        path: sequence of image file paths.
        augment: if True (default, the original behaviour), apply random horizontal
            and vertical flips; pass False for deterministic outputs.
    """
    def __init__(self, parmas, path, augment=True):
        self.args = parmas
        self.path = path
        self.augment = augment
        # Built once here instead of on every __getitem__ call.
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def trans(self, image):
        """Independently flip ``image`` horizontally and vertically, each with probability ~0.5."""
        if random.random() > 0.5:
            image = transforms.RandomHorizontalFlip(1)(image)
        if random.random() > 0.5:
            image = transforms.RandomVerticalFlip(1)(image)
        return image

    def __getitem__(self, index):
        path = self.path[index]
        # Use the config passed to the constructor, not the notebook-global `params`.
        size = self.args['image_size']
        # The context manager closes the file handle. convert()/resize() return new
        # images, so the result stays valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert('RGB').resize((size, size))
        image = self.to_tensor(image)
        if self.augment:
            image = self.trans(image)
        return image, path

    def __len__(self):
        return len(self.path)



# Collect every .jpeg in data_path. os.path.join works with or without a trailing
# slash on data_path. Sorting makes the iteration order reproducible across runs.
image_list = sorted(glob(os.path.join(params['data_path'], '*.jpeg')))
if not image_list:
    raise FileNotFoundError(f"No '*.jpeg' files found in {params['data_path']!r}")

train_dataset = CustomDataset(params, image_list)
# Inference only: each image is classified independently, so shuffling is unnecessary.
dataloader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=False)


In [None]:

class FeatureExtractor(nn.Module):
    """Feature extractor block: a timm backbone with its last child module removed.

    The last child is presumably the classification head. All other children are
    wrapped, in order, in ``self.feature_ex`` (an ``nn.Sequential``). That attribute
    name and the child order determine the state-dict keys, so they must stay
    unchanged for saved checkpoints to load.

    Args:
        model_name: timm model identifier (the default keeps the original backbone).
        pretrained: whether timm should download its pretrained weights. Pass False
            when a fine-tuned checkpoint will be loaded over the whole model anyway,
            to avoid a network fetch.
    """
    def __init__(self, model_name: str = 'tf_efficientnetv2_xl', pretrained: bool = True):
        super(FeatureExtractor, self).__init__()
        cnn1 = timm.create_model(model_name, pretrained=pretrained)
        # timm's reported feature width before its classifier (None if not exposed).
        self.num_features = getattr(cnn1, 'num_features', None)
        self.feature_ex = nn.Sequential(*list(cnn1.children())[:-1])

    def forward(self, inputs):
        """Run the truncated backbone on ``inputs`` and return its output."""
        return self.feature_ex(inputs)
class custom_model(nn.Module):
    """Image classifier: a feature extractor followed by a single linear layer.

    Args:
        num_classes: number of output logits.
        image_feature_dim: width of the (flattened) features produced by
            ``feature_extractor_scale1``; must equal the linear layer's input size.
        feature_extractor_scale1: backbone module (e.g. ``FeatureExtractor``) that
            maps an image batch to per-image features.
    """
    def __init__(self, num_classes, image_feature_dim, feature_extractor_scale1: nn.Module):
        super(custom_model, self).__init__()
        self.num_classes = num_classes
        self.image_feature_dim = image_feature_dim

        # Backbone; expected to already have its own classification head removed.
        self.feature_extractor = feature_extractor_scale1
        # Classification layer
        self.classification_layer = nn.Linear(image_feature_dim, num_classes)

    def forward(self, inputs):
        """Return logits of shape (batch_size, num_classes) for the image batch ``inputs``."""
        features = self.feature_extractor(inputs)
        # Collapse trailing dims, e.g. (B, C, 1, 1) -> (B, C). This is a no-op when
        # the extractor already returns (B, C).
        features = torch.flatten(features, 1)
        return self.classification_layer(features)
    
class SAM(torch.optim.Optimizer):
    """Sharpness Aware Minimization (SAM) wrapper around a base optimizer.

    Each update has two steps:
    - ``first_step`` moves every parameter that has a gradient to ``w + e(w)``, with
      ``e(w) = rho * g / (||g|| + 1e-12)`` (elementwise-scaled when ``adaptive``).
    - ``second_step`` restores ``w`` and lets the base optimizer apply the gradient
      computed at the perturbed point.

    Args:
        params: iterable of parameters or param-group dicts, as for torch.optim.Optimizer.
        base_optimizer: optimizer *class*; it is instantiated here as
            ``base_optimizer(self.param_groups, **kwargs)``.
        rho: perturbation radius; must be non-negative.
        adaptive: if True, the perturbation is scaled elementwise by ``p**2`` and the
            gradient norm is computed on ``|p| * grad``.
        **kwargs: forwarded to the base optimizer (e.g. ``lr``, ``momentum``) and also
            stored in this optimizer's defaults.

    Raises:
        AssertionError: if ``rho`` is negative.
    """
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share one param_groups list with the base optimizer, so hyper-parameter
        # changes (e.g. lr) made through either object are seen by both.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb the weights to ``w + e(w)`` using the gradients currently in ``p.grad``.

        Gradients must already be populated (a forward/backward pass must have run).
        The unperturbed weights are saved in ``self.state[p]["old_p"]`` for
        ``second_step``. If ``zero_grad`` is True, gradients are zeroed afterwards.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # 1e-12 avoids division by zero when the gradient norm is 0.
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the weights saved by ``first_step``, then apply the base optimizer step.

        Intended to run after gradients have been recomputed at the perturbed point.
        Only parameters that currently have a gradient are restored. If ``zero_grad``
        is True, gradients are zeroed after the update.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update, using ``closure`` for the second forward/backward pass.

        The caller must already have computed gradients at ``w`` before calling this:
        - ``first_step`` consumes those gradients;
        - ``closure`` then recomputes them at ``w + e(w)``; it is re-enabled for
          autograd because this method is decorated with ``torch.no_grad``;
        - ``second_step`` applies the update.

        Raises:
            AssertionError: if ``closure`` is None.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the L2 norm over all gradients (``|p| * grad`` for adaptive groups).

        Per-parameter norms are moved to the device of the first parameter before
        being combined.
        NOTE(review): if no parameter has a gradient, ``torch.stack`` receives an
        empty list and raises.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

    def load_state_dict(self, state_dict):
        """Load optimizer state, then re-link the base optimizer to this optimizer's
        param_groups so the two keep sharing them after loading."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
        
def disable_running_stats(model):
    """Set ``momentum`` to 0 on every ``_BatchNorm`` submodule of ``model``.

    Under PyTorch's BatchNorm update rule, zero momentum stops ``running_mean`` and
    ``running_var`` from moving. Each layer's previous momentum is stashed on the
    layer as ``backup_momentum`` so it can be restored later.
    """
    def _freeze_bn(layer):
        if not isinstance(layer, _BatchNorm):
            return
        layer.backup_momentum = layer.momentum
        layer.momentum = 0

    model.apply(_freeze_bn)

def enable_running_stats(model):
    """Restore the BatchNorm momentum saved by ``disable_running_stats``.

    Every ``_BatchNorm`` submodule of ``model`` that carries a ``backup_momentum``
    attribute gets its ``momentum`` set back to that value. Modules without a backup
    are left untouched.
    """
    def _enable(module):
        if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum

    # Previously _enable was defined but never applied, so this function did nothing.
    model.apply(_enable)
            
import transformers

# Build the classifier: backbone (FeatureExtractor) + linear head. The feature width
# 1280 must match the checkpoint's classification layer (load_state_dict is strict).
Feature_Extractor = FeatureExtractor()
model = custom_model(params['n_classes'], 1280, Feature_Extractor)
model = model.to(device)

# The optimizer and metric are not used by the visible inference cells; they are kept
# so other cells that reference them still work.
base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, lr=params['lr'], momentum=0.9)
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=params['n_classes']).to(device)

CHECKPOINT_PATH = "../../model/detail_classification/BRNT/Eff_v2_XL_SAM_114.pt"
# map_location places the tensors directly on `device`, so a checkpoint saved from a
# different GPU (or opened on a CPU-only machine) still loads.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint)

In [11]:
import shutil
# Classify every image and copy it into the folder for its predicted class.
# NOTE: CustomDataset applies random flips, so predictions may differ between runs.
CLASS_DIRS = ('../../temp/BRNT/class1/', '../../temp/BRNT/class2/')
for class_dir in CLASS_DIRS:
    os.makedirs(class_dir, exist_ok=True)  # shutil.copyfile fails if the folder is missing

val = tqdm(dataloader)
model.eval()
count = 0
val_running_loss = 0.0
acc_loss = 0
C1_list = []  # read by a later cell; not filled here
C2_list = []

with torch.no_grad():
    for x, path in val:
        count += 1
        x = x.to(device).float()
        predict = model(x)
        # argmax of the logits equals argmax of their softmax, so softmax is skipped.
        # Handle every item in the batch instead of assuming batch_size == 1.
        for source, label in zip(path, predict.argmax(dim=1).tolist()):
            class_dir = CLASS_DIRS[0] if label == 0 else CLASS_DIRS[1]
            shutil.copyfile(source, os.path.join(class_dir, os.path.basename(source)))
        val.set_description(f"Step: {count}")


Step: 3217:   2%|▏         | 3215/147583 [04:51<3:39:24, 10.97it/s]

In [None]:

# Copy previously collected file paths into their class folders:
# every C1_list entry goes to class1, then every C2_list entry to class2.
for collected, target_dir in ((C1_list, '../../temp/BRNT/class1/'),
                              (C2_list, '../../temp/BRNT/class2/')):
    for source in collected:
        destination = target_dir + os.path.basename(source)
        shutil.copyfile(source, destination)
    
