In [1]:
import torch
import torchvision.transforms as transforms
import timm
from tqdm import tqdm
from torch.utils.data import Dataset, Subset, DataLoader, Sampler
import glob
import os
from PIL import Image
from collections import defaultdict, Counter
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, accuracy_score, recall_score, f1_score, confusion_matrix, roc_auc_score
import seaborn as sns
from sklearn.manifold import TSNE
import json
import math
import time

# Limit the CPU count reported to joblib/loky-based parallel code.
os.environ["LOKY_MAX_CPU_COUNT"] = "6"
# NOTE(review): presumably intended to enable CUDA device-side assertions; torch has
# already been imported at this point, so confirm the variable is still honoured.
os.environ['TORCH_USE_CUDA_DSA'] = "1"

SEED = 42

# Apply the seed to every RNG this notebook draws from. Without this the
# `random.*` based sampling further down differs on every kernel restart,
# which makes runs with different settings impossible to compare.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset, DatasetDict
from huggingface_hub import HfApi, HfFolder

# Load the 'uoft-cs/cifar10' dataset via `datasets.load_dataset`.
# The variable must be named `<config.object_dataset>_dataset`: the training
# cells further down look the dataset up by that name.
cifar_dataset = load_dataset('uoft-cs/cifar10')
# Bare expression so the notebook shows the splits / features summary.
cifar_dataset

# Alternative dataset, currently disabled:
# mnist_dataset = load_dataset("mnist", trust_remote_code=True)
# mnist_dataset

DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
})

In [3]:
# imagenet_base_dir = r'C:\datasets\tiny-imagenet-200'

# train_ds = load_dataset("imagefolder", data_dir=os.path.join(imagenet_base_dir, 'train_preprocess'))["train"]
# test_ds  = load_dataset("imagefolder", data_dir=os.path.join(imagenet_base_dir, 'valid_preprocess'))['validation']

# imagenet_dataset = DatasetDict({
#     "train": train_ds,
#     "test": test_ds
# })

# imagenet_dataset

In [4]:
# Preprocessing applied to training images. Only the tensor conversion is
# active; the commented entries are optional steps that are switched off.
train_transform = transforms.Compose(
    [
        # transforms.Resize((224, 224)),
        # transforms.RandomHorizontalFlip(p=0.5),
        # transforms.RandomRotation(10),
        # transforms.ColorJitter(0.1, 0.1, 0.1),
        # transforms.Grayscale(3),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
        #                      std=[0.2470, 0.2435, 0.2616])
    ]
)

# Preprocessing applied to validation images: same active step as training,
# with no random augmentation options listed.
valid_transform = transforms.Compose(
    [
        # transforms.Resize((224, 224)),
        # transforms.Grayscale(3),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
        #                      std=[0.2470, 0.2435, 0.2616])
    ]
)

In [5]:
class custom_dataset(Dataset):
    """Map-style dataset exposing a chosen subset of rows of an indexable dataset.

    Every underlying row must be a mapping with an 'img' entry (either a file
    path string or an object `transform` accepts) and a 'label' entry.
    """

    def __init__(self, vanila_dataset, image_indexes, transform):
        """
        Args:
            vanila_dataset: indexable container; `vanila_dataset[i]` returns a
                mapping with 'img' and 'label' keys.
            image_indexes: sequence of row indexes of `vanila_dataset` that
                make up this dataset, in iteration order.
            transform: callable applied to each image before it is returned.
        """
        self.dataset = vanila_dataset
        self.image_indexes = image_indexes

        self.transform = transform

    def __len__(self):
        """Number of selected rows."""
        return len(self.image_indexes)

    def __getitem__(self, index):
        """Return `(transformed_image, label, row_index)` for position `index`.

        `row_index` is the index into the underlying dataset, so callers can
        attribute per-sample statistics back to the original row.
        """
        selected_index = self.image_indexes[index]

        # Fetch the row once and reuse it for both fields; looking the row up
        # separately for 'img' and 'label' repeats whatever work the backing
        # dataset does per row access.
        _row = self.dataset[selected_index]

        _image = _row['img']
        if isinstance(_image, str):
            # A string is treated as a path to an image file on disk.
            _image = Image.open(_image).convert('RGB')

        _image = self.transform(_image)

        return _image, _row['label'], selected_index

In [6]:
def worker_init_fn(worker_id):
    """Seed the Python, NumPy and PyTorch RNGs with `SEED + worker_id`.

    The signature matches what `DataLoader(worker_init_fn=...)` expects, so
    each worker process gets its own deterministic random stream.

    Args:
        worker_id: integer id of the worker being initialised.
    """
    seed = worker_id + SEED

    random.seed(seed)
    np.random.seed(seed)
    # One call each is enough: the original seeded torch twice and seeded the
    # current CUDA device separately before seeding all CUDA devices.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [7]:
class custom_data_factory():
    """Builds a class-balanced training subset and re-samples it between epochs.

    The factory keeps, per label, a list of selected training-row indexes
    (`selected_items`). After an epoch the caller reports per-sample losses via
    `collect_losses`; `renew_data_loader` then builds the next subset from
    (a) rows of the finished epoch re-drawn with probability proportional to
    their normalised loss and (b) rows that were not used in that epoch.
    """

    def __init__(self, train_dataset, test_dataset, train_transform, valid_transform, sample_num:int=None, normalize_method:str="base", batch_sizes:int=None, save_path:str=None):
        """
        Args:
            train_dataset: dataset exposing `num_rows` and `dataset[i]['label']`.
            test_dataset: sized, indexable dataset used in full for validation.
            train_transform: transform handed to the training `custom_dataset`.
            valid_transform: transform handed to the validation `custom_dataset`.
            sample_num: rows to select per label; `None` selects every row.
                Capped at the number of rows a label actually has.
            normalize_method: 'base' (min-max) or 'softmax' loss normalisation.
            batch_sizes: batch size of both loaders (defaults to 128).
            save_path: path, relative to the working directory, of the JSON
                history written by `save_data_state`; falsy disables saving.

        Raises:
            AssertionError: if `normalize_method` is not a supported name.
        """
        # Fail fast, before the (potentially slow) pass over the dataset.
        assert normalize_method in ['base', 'softmax'], 'Normalize method error, methods should be one of ["base", "softmax"]'
        self.normalize_method = normalize_method

        # label -> indexes of every training row carrying that label
        self.train_data_dict = {}
        for _index in range(train_dataset.num_rows):
            _label = train_dataset[_index]['label']
            self.train_data_dict.setdefault(_label, []).append(_index)

        # (sic) attribute spelling kept so existing readers keep working.
        self.train_datset = train_dataset
        self.test_dataset = test_dataset

        self.sample_num = sample_num
        # Sample WITHOUT replacement: the quota is already capped at the class
        # size, and drawing with replacement would silently duplicate rows.
        self.selected_items = {
            _label: random.sample(item_index, k=self._class_quota(item_index))
            for _label, item_index in self.train_data_dict.items()
        }

        self.save_path = save_path if save_path else None
        self.batch_sizes = batch_sizes if batch_sizes else 128

        self.train_transform = train_transform
        self.valid_transform = valid_transform

        # Populated by get_data_set(); None until then.
        self.train_set = None
        self.valid_set = None

        # (row_index, label, loss) triples reported for the running epoch.
        self.losses = []

    def _class_quota(self, class_indexes):
        """Rows to select for a label owning `class_indexes` rows."""
        if self.sample_num is None:
            return len(class_indexes)
        return min(len(class_indexes), self.sample_num)

    def _loss_weights(self, _losses):
        """Turn per-sample losses into non-negative sampling weights."""
        _low, _high = min(_losses), max(_losses)
        if not (_high - _low):
            # All losses equal: every sample is equally likely.
            return [1 / len(_losses)] * len(_losses)

        if self.normalize_method == 'base':
            # min/max are hoisted so normalisation is linear in the sample count.
            _span = _high - _low + 1e-8
            return [(_loss - _low) / _span for _loss in _losses]

        # 'softmax' (shifted by the maximum for numerical stability)
        _exp = np.exp(np.asarray(_losses, dtype=float) - _high)
        return (_exp / np.sum(_exp)).tolist()

    def _resample_label(self, _label, _observations, _ratio):
        """Pick the next epoch's rows for one label.

        `_observations` holds the `(row_index, loss)` pairs reported for this
        label; `_ratio` (already clamped to [0, 1]) is the share of the quota
        filled with rows that were not observed in the finished epoch.
        """
        class_indexes = self.train_data_dict[_label]
        quota = self._class_quota(class_indexes)

        seen = {_path for _path, _ in _observations}
        unseen = [_path for _path in class_indexes if _path not in seen]

        if _observations:
            # Derive the kept count from the fresh count so both always add
            # up to the quota (no float truncation, no short pool surprises).
            n_fresh = min(round(quota * _ratio), len(unseen))
            n_keep = quota - n_fresh
        else:
            # Nothing was observed for this label: fill up with fresh rows.
            n_fresh, n_keep = min(quota, len(unseen)), 0

        picked = []
        if n_keep:
            _paths, _losses = zip(*_observations)
            picked.extend(random.choices(_paths, weights=self._loss_weights(_losses), k=n_keep))
        picked.extend(random.sample(unseen, k=n_fresh))
        random.shuffle(picked)

        return picked

    def get_data_set(self):
        """Build the training and validation datasets from the current selection.

        Returns:
            `(train_set, valid_set)`; both are also stored on the instance.
        """
        _item_list = [_index for _indexes in self.selected_items.values() for _index in _indexes]
        random.shuffle(_item_list)

        self.train_set = custom_dataset(self.train_datset, _item_list, self.train_transform)
        # Use the real size of the test split instead of a hard-coded 10000.
        self.valid_set = custom_dataset(self.test_dataset, list(range(len(self.test_dataset))), self.valid_transform)

        return self.train_set, self.valid_set

    def get_data_loader(self):
        """Wrap the current datasets in loaders and return `(train_loader, valid_loader)`."""
        if self.train_set is None or self.valid_set is None:
            self.get_data_set()

        train_loader = DataLoader(
            self.train_set, pin_memory=True,
            batch_size=self.batch_sizes,
            num_workers=0,
            drop_last=True
        )

        # Validation must see every sample; dropping the last partial batch
        # would silently exclude rows from the reported metrics.
        valid_loader = DataLoader(
            self.valid_set, pin_memory=True,
            batch_size=self.batch_sizes,
            num_workers=0,
            drop_last=False
        )

        return train_loader, valid_loader

    def collect_losses(self, _paths, _labels, _losses):
        """Record one batch of per-sample losses.

        Args:
            _paths: tensor of dataset row indexes of the batch.
            _labels: tensor of the corresponding labels.
            _losses: tensor of the corresponding un-reduced losses.

        Raises:
            AssertionError: if the three inputs differ in length.
        """
        assert len(_paths) == len(_labels), 'Losses information error'
        assert len(_paths) == len(_losses), 'Losses information error'

        for _path, _label, _loss in zip(_paths, _labels, _losses):
            self.losses.append((_path.item(), _label.item(), _loss.item()))

    def save_data_state(self):
        """Append the current selection to the JSON history at `save_path`.

        Does nothing when no `save_path` was configured.
        """
        if not self.save_path:
            return

        _save_path = os.path.join(os.getcwd(), self.save_path)
        if os.path.exists(_save_path):
            with open(_save_path, 'r', encoding='utf-8-sig') as json_file:
                saved_states = json.load(json_file)
            saved_states.append(self.selected_items)

        else:
            os.makedirs(os.path.dirname(_save_path), exist_ok=True)
            saved_states = [self.selected_items]

        with open(_save_path, 'w', encoding='utf-8-sig') as json_file:
            json.dump(saved_states, json_file, ensure_ascii=False, indent=2)

    def renew_data_loader(self, resample_ratio):
        """Re-sample the training subset from the collected losses.

        Args:
            resample_ratio: share (clamped to [0, 1]) of each label's quota
                that is replaced by rows not observed in the finished epoch;
                the rest is re-drawn, loss-weighted, from the observed rows.

        Returns:
            `((train_loader, valid_loader), seconds_spent_resampling)`.
        """
        start_time = time.time()

        # Group the reported (row_index, loss) pairs by label.
        observations = {}
        for _path, _label, _loss in self.losses:
            observations.setdefault(_label, []).append((_path, _loss))

        _ratio = min(max(resample_ratio, 0.0), 1.0)

        # Build the selection per label key; pairing two dicts by position
        # would mix up labels whenever their insertion orders differ.
        self.selected_items = {
            _label: self._resample_label(_label, observations.get(_label, []), _ratio)
            for _label in self.train_data_dict
        }

        self.losses = []

        self.get_data_set()

        end_time = time.time()

        return self.get_data_loader(), end_time - start_time

In [8]:
# selected_items_dict = {}
# for _path, _label, _loss in datafactory.losses:
#     if selected_items_dict.get(_label) == None:
#         selected_items_dict[_label] = []

#     selected_items_dict[_label].append((_path, _loss))

# selected_items_dict

In [9]:
def evaluate(model, valid_loader, device=None):
    """Compute classification metrics of `model` over `valid_loader`.

    Args:
        model: module mapping an image batch to per-class logits.
        valid_loader: iterable of `(images, labels, indexes)` batches.
        device: device the images are moved to. Defaults to the device of the
            model's first parameter ('cuda' if the model has no parameters),
            so the function no longer requires the model to live on a GPU.

    Returns:
        dict with 'accuracy', weighted 'precision' / 'recall' / 'f1' and
        one-vs-rest 'roc_auc'.

    Note:
        The model is switched to eval mode and left there; callers that
        continue training must call `model.train()` themselves.
    """
    model.eval()

    if device is None:
        _first_param = next(model.parameters(), None)
        device = _first_param.device if _first_param is not None else 'cuda'

    prob_batches, label_batches = [], []

    with torch.no_grad():
        for images, labels, _ in tqdm(valid_loader, desc=f'Validating', leave=False):
            outputs = model(images.to(device))
            probs = F.softmax(outputs, dim=1)

            prob_batches.append(probs.cpu().numpy())
            # Labels are only needed on the host for the metric functions,
            # so they are never copied to the device.
            label_batches.append(labels.cpu().numpy())

    all_probs_roc = np.concatenate(prob_batches, axis=0)
    all_preds = np.argmax(all_probs_roc, axis=1)
    all_labels = np.concatenate(label_batches, axis=0)

    return {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, average='weighted', zero_division=0),
        'recall': recall_score(all_labels, all_preds, average='weighted', zero_division=0),
        'f1': f1_score(all_labels, all_preds, average='weighted', zero_division=0),
        'roc_auc':roc_auc_score(all_labels, all_probs_roc, multi_class='ovr'),
        # 'confusion_matrix': confusion_matrix(all_labels, all_probs)
    }

In [10]:
from dataclasses import dataclass
from dataclasses import asdict
import yaml

@dataclass
class config:
    """Settings of one experiment run; the defaults are the values in use."""

    # Optimisation: number of training epochs and SGD learning rate.
    epoch: int = 1000
    lr: float = 0.01

    # Backbone: architecture name and whether the 'trained' (True) or the
    # 'untrained' (False) base checkpoint file is loaded.
    model_name: str = 'resnet50'
    trained_model: bool = True

    # Data: share of the per-label quota replaced when re-sampling, loader
    # batch size, dataset prefix (selects the `<prefix>_dataset` variable and
    # the results folder) and the loss normalisation name
    # ('base' or 'softmax' are the names custom_data_factory accepts).
    resample_ratio: float = 0.7
    batch_size: int = 128
    object_dataset: str = 'cifar'
    normalize_method: str = 'base'


In [11]:
# base_model = timm.create_model(config.model_name, pretrained=config.trained_model, num_classes=200).to('cpu')
# torch.save(timm.create_model('resnet50', pretrained=True, num_classes=200), os.path.join(os.getcwd(), 'trained_base_model_200.pt'))
# torch.save(timm.create_model('resnet50', pretrained=False, num_classes=200), os.path.join(os.getcwd(), 'untrained_base_model_200.pt'))

In [12]:
# Loss-guided re-sampling experiment: for several per-label sample sizes,
# train on a subset that is re-drawn after every epoch and log the metrics.

# Persist the configuration next to the results it produced.
results_dir = os.path.join(os.getcwd(), 'results', config.object_dataset)
os.makedirs(results_dir, exist_ok=True)

yaml_path = os.path.join(results_dir, 'config.yaml')
with open(yaml_path, 'w') as yaml_file:
    yaml.dump(asdict(config()), yaml_file)

# Plain namespace lookup of `<object_dataset>_dataset` instead of eval():
# same result, without executing a dynamically built code string.
object_dataset = globals()[f"{config.object_dataset}_dataset"]

# for sample_nums in [100, 125, 170, 250]:
for sample_nums in [1000, 1500, 2000, 3000]:
    datafactory = custom_data_factory(
        object_dataset['train'],
        object_dataset['test'],
        sample_num=sample_nums,
        train_transform=train_transform,
        valid_transform=valid_transform,
        batch_sizes=config.batch_size,
        # Forward the configured scheme; it used to be silently ignored.
        normalize_method=config.normalize_method,
        # os.path.join keeps the path valid on non-Windows systems too.
        save_path=os.path.join('results', config.object_dataset, 'dataset_history.json')
    )

    # Every sample size starts again from the same base checkpoint.
    # NOTE: weights_only=False un-pickles arbitrary objects; only load
    # checkpoint files you created yourself.
    checkpoint_name = 'trained_base_model.pt' if config.trained_model else 'untrained_base_model.pt'
    model = torch.load(os.path.join(os.getcwd(), checkpoint_name), weights_only=False)

    model.to('cuda')
    optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

    train_history = {}
    # Un-reduced loss: the per-sample values drive the re-sampling.
    criterion = nn.CrossEntropyLoss(reduction='none')

    datafactory.get_data_set()
    train_loader, valid_loader = datafactory.get_data_loader()

    # The history file only depends on the sample size, so resolve it once.
    json_path = os.path.join(results_dir, f'collect_losses_train_history_{sample_nums}.json')
    start_time = time.time()

    for epoch in range(config.epoch):
        model.train()
        total_losses = []

        for images, labels, selected_index in tqdm(train_loader, desc=f'Training epoch: {epoch+1:>3}/{config.epoch}', leave=False):
            images = images.to('cuda')
            labels = labels.to('cuda')

            optimizer.zero_grad()

            outputs = model(images)

            loss = criterion(outputs, labels)

            batch_loss = loss.mean()
            batch_loss.backward()
            optimizer.step()

            total_losses.append(batch_loss.item())

            # Detach so the stored per-sample values carry no autograd state.
            datafactory.collect_losses(selected_index, labels.cpu(), loss.detach().cpu())

        valid_results = evaluate(model, valid_loader=valid_loader)
        valid_results['loss'] = np.mean(total_losses)
        # 'time' = [seconds since training start, seconds spent re-sampling]
        valid_results['time'] = [time.time() - start_time]
        train_history[epoch] = valid_results

        (train_loader, valid_loader), spend_time = datafactory.renew_data_loader(resample_ratio=config.resample_ratio)
        valid_results['time'].append(spend_time)

        # Rewrite the full history every epoch so an interrupted run keeps
        # everything logged so far.
        with open(json_path, 'w', encoding='utf-8-sig') as json_file:
            json.dump(train_history, json_file, indent=2, ensure_ascii=False)

                                                                                

In [15]:
# Baseline experiment: train the same base checkpoint on the complete
# training split (no re-sampling) and log the same metrics per epoch.

# NOTE: weights_only=False un-pickles arbitrary objects; only load checkpoint
# files you created yourself.
checkpoint_name = 'trained_base_model.pt' if config.trained_model else 'untrained_base_model.pt'
model = torch.load(os.path.join(os.getcwd(), checkpoint_name), weights_only=False)

model.to('cuda')

train_history = {}
criterion = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

# Plain namespace lookup of `<object_dataset>_dataset` instead of eval().
object_dataset = globals()[f"{config.object_dataset}_dataset"]

# Use the real split sizes rather than hard-coded 50000 / 10000 so the cell
# also works for datasets of a different size.
train_set = custom_dataset(object_dataset['train'], list(range(len(object_dataset['train']))), train_transform)
valid_set = custom_dataset(object_dataset['test'], list(range(len(object_dataset['test']))), valid_transform)

train_loader = DataLoader(
    train_set, pin_memory=True,
    batch_size=config.batch_size,
    num_workers=0,
    # Re-shuffle every epoch; without this SGD sees the rows in the same
    # stored order each epoch, unlike the re-sampling run which shuffles.
    shuffle=True,
    drop_last=True
)

valid_loader = DataLoader(
    valid_set, pin_memory=True,
    batch_size=config.batch_size,
    num_workers=0,
    # Keep the final partial batch so every validation row is scored.
    drop_last=False
)

# The history file does not change between epochs, so resolve it once.
json_path = os.path.join(os.getcwd(), rf'results/{config.object_dataset}/base_history.json')
os.makedirs(os.path.dirname(json_path), exist_ok=True)

start_time = time.time()
for epoch in range(config.epoch):
    model.train()
    total_losses = []

    for images, labels, selected_index in tqdm(train_loader, desc=f'Training epoch: {epoch+1:>3}/{config.epoch}', leave=False):
        images = images.to('cuda')
        labels = labels.to('cuda')

        optimizer.zero_grad()

        outputs = model(images)

        loss = criterion(outputs, labels)

        batch_loss = loss.mean()
        batch_loss.backward()
        optimizer.step()

        total_losses.append(batch_loss.item())

    valid_results = evaluate(model, valid_loader=valid_loader)
    valid_results['loss'] = np.mean(total_losses)
    valid_results['time'] = time.time() - start_time
    train_history[epoch] = valid_results

    # Rewrite the full history every epoch so an interrupted run keeps
    # everything logged so far.
    with open(json_path, 'w', encoding='utf-8-sig') as json_file:
        json.dump(train_history, json_file, indent=2, ensure_ascii=False)

                                                                            