In [39]:
from datasets import load_dataset, concatenate_datasets
from transformers import TrainingArguments, Trainer
from transformers import AutoImageProcessor, ResNetForImageClassification, ResNetConfig
import evaluate
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from sklearn.metrics import accuracy_score

# Training the Task Ensemble

In [None]:
# Download & process dataset
dataset = load_dataset("Maysee/tiny-imagenet")

processor = AutoImageProcessor.from_pretrained("preprocessor_config.json",)

def process_example(example):
    if example['image'].mode != 'RGB':
        example['image'] = example['image'].convert('RGB')
    example = processor(example['image'], return_tensors='pt')
    example['pixel_values'] = example['pixel_values'].squeeze()
    return example

dataset['valid'] = dataset['valid'].map(process_example)
dataset['valid'].set_format("pt", columns=['pixel_values'], output_all_columns=True)

In [None]:
tasks = {
    1: [182, 61, 120, 193, 12, 23, 146, 165, 142, 171, 9, 45, 50, 192, 123, 156, 31, 89, 100, 65, 5, 75,
        157, 158, 139, 154, 35, 67, 58, 105, 29, 17, 150, 122, 15, 62, 167, 174, 60, 110, 133, 145, 199],
    2: [114, 169, 40, 18, 19, 49, 187, 83, 16, 34, 47, 59, 166, 68, 32, 197, 52, 3, 51, 190, 66, 94, 170,
        196, 116, 138, 184, 181, 137, 128, 14, 55, 140, 76, 135, 121, 88, 124, 85, 130, 43, 162, 24, 180,
        20, 63, 155, 107, 96, 134, 175, 69, 82, 109, 56, 115, 136, 70, 80, 41],
    3: [10, 92, 103, 86, 189, 64, 179, 147, 13, 53, 198, 1, 72, 48, 77, 36, 42, 73],
    4: [39, 195, 126, 191, 99, 144, 160, 104, 159, 21, 161, 6, 176, 113, 168, 102, 194, 148, 30, 119,
        87, 27, 106, 2, 143, 74, 79, 132, 178, 101, 28, 186, 97, 111, 91, 117, 127, 22, 71, 118, 44, 177,
        153, 172],
    5: [81, 151, 0, 37, 33, 11, 141, 112, 183, 149, 7, 173, 125, 108, 185, 25, 129, 163, 84, 54, 26, 152,
        78, 38, 188, 4, 95, 98, 57, 131, 90, 46, 8, 164, 93]
}
# Create dataset for each task
train_datasets = {}
val_datasets = {}
for t in tasks.keys():
    split = dataset['train'].filter(lambda img: img['label'] in tasks[t]).train_test_split(test_size=0.2)
    train_datasets[t] = split['train']
    val_datasets[t] = split['test']
    val_datasets[t] = val_datasets[t].map(process_example)
    val_datasets[t].set_format("pt", columns=['pixel_values'], output_all_columns=True)



In [None]:
# train baseline models
config = ResNetConfig(num_labels=200, num_channels=3)
metric = evaluate.load("accuracy")
models = {}

for t in tasks.keys():
    models[t] = ResNetForImageClassification(config)
    training_args = TrainingArguments(output_dir=f"./task_{t}", evaluation_strategy="epoch", num_train_epochs=50,)
    train_data = train_datasets[t].map(process_example)
    train_data.set_format("pt", columns=['pixel_values'], output_all_columns=True)
    trainer = Trainer(
        model=models[t],
        args=training_args,
        train_dataset=train_data,
        eval_dataset=dataset['valid'],
    )
    trainer.train()


In [57]:
# evaluate baseline models on combined validation dataset
for t in tasks.keys():
    predictions = []
    with torch.no_grad():
        for example in dataset['valid']:
            predicted_label = models[t](example['pixel_values'].unsqueeze(0).cuda()).logits.argmax(-1).item()
            predictions.append(predicted_label)
    acc = accuracy_score(dataset['valid']['label'], predictions)
    print('Task {} model\tAccuracy: {:.2f}%'.format(t,acc))

Task 1 model	Accuracy: 0.15%
Task 2 model	Accuracy: 0.20%
Task 3 model	Accuracy: 0.07%
Task 4 model	Accuracy: 0.15%
Task 5 model	Accuracy: 0.12%


In [None]:
# evaluate baseline models on tasks only validation dataset
for t in tasks.keys():
    predictions = []
    with torch.no_grad():
        for example in val_datasets[t]:
            predicted_label = models[t](example['pixel_values'].unsqueeze(0).cuda()).logits.argmax(-1).item()
            predictions.append(predicted_label)
    acc = accuracy_score(val_datasets[t]['label'], predictions)
    print('Task {}\tAccuracy: {:.2f}%'.format(t,acc))

In [None]:
# load models
models = {}
for t in tasks.keys():
    models[t] = ResNetForImageClassification.from_pretrained(f'task_{t}')

# Selection and Aggregation Schemes

In [59]:
# random selection
random.seed(2)
predictions = []
with torch.no_grad():
    for example in dataset['valid']:
        m = random.choice(list(models.keys()))
        predicted_label = models[m](example['pixel_values'].unsqueeze(0).cuda()).logits.argmax(-1).item()
        predictions.append(predicted_label)
acc = accuracy_score(dataset['valid']['label'], predictions)
print('Accuracy: {:.2f}%'.format(acc))

Accuracy: 0.1409


In [40]:
# Oracle Selection
predictions = []
with torch.no_grad():
    for example in dataset['valid']:
        m = None
        for t in tasks.keys():
            if example['label'] in tasks[t]:
                m = t
                break
        predicted_label = models[m](example['pixel_values'].unsqueeze(0).cuda()).logits.argmax(-1).item()
        predictions.append(predicted_label)
acc = accuracy_score(dataset['valid']['label'], predictions)
print('Accuracy: {:.2f}%'.format(acc))

Accuracy: 0.69%


In [61]:
# Confidence-based Selection
predictions = []
with torch.no_grad():
    for example in dataset['valid']:
        candidates = {}
        for t in tasks.keys():
            logits =  models[t](example['pixel_values'].unsqueeze(0).cuda()).logits
            predicted_label = logits.argmax(-1).item()
            pred_confidence = logits.max().item()
            candidates[predicted_label] = pred_confidence
        predicted_label = max(candidates, key=candidates.get)
        predictions.append(predicted_label)
acc = accuracy_score(dataset['valid']['label'], predictions)
print('Accuracy: {:.2f}%'.format(acc))

Accuracy: 0.4248


In [62]:
# Entropy-based Selection
def entropy(logits):
    probs = torch.softmax(logits, dim=1)
    log_probs = torch.log(probs + 1e-7) # Add a small epsilon to avoid log(0)
    entropy = -torch.sum(probs * log_probs, dim=1)
    return entropy.item()
predictions = []
with torch.no_grad():
    for example in dataset['valid']:
        candidates = {}
        for t in tasks.keys():
            logits =  models[t](example['pixel_values'].unsqueeze(0).cuda()).logits
            predicted_label = logits.argmax(-1).item()
            entr = entropy(logits)
            candidates[predicted_label] = entr
        predicted_label = min(candidates, key=candidates.get)
        predictions.append(predicted_label)
acc = accuracy_score(dataset['valid']['label'], predictions)
print('Accuracy: {:.2f}%'.format(acc))

Accuracy: 0.4364


In [5]:
# Stacking
# Define Stacked model
class StackingEnsemble(nn.Module):
    def __init__(self, models, num_classes=200, hidden_size=512):
        super(StackingEnsemble, self).__init__()
        self.models = nn.ModuleList(models)
        self.num_classes = num_classes
        self.pooler = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(len(models) * hidden_size, num_classes)
        )

    def forward(self, pixel_values: torch.FloatTensor = None,):
        # Pass input through the ensemble
        outputs = []
        for model in self.models:
            last_hidden_state = model(pixel_values, return_dict=True, output_hidden_states=True)[
                'hidden_states'][-1]
            pooled_output = self.pooler(last_hidden_state)
            outputs.append(pooled_output)

        # Stack the penultimate layer feature representations
        stacked_features = torch.cat(outputs, dim=1)
        stacked_features = stacked_features.view(stacked_features.size(0), -1)

        # Pass stacked features through the classification head
        logits = self.classifier(stacked_features)

        return logits


In [None]:
# Dataset 
# Train baseline models on 80% of train data
config = ResNetConfig.from_json_file('task_1\config.json')
metric = evaluate.load("accuracy")
models = {}
# Keep 20% of each task to finetune Stacked model
finetune_data = []
for t in tasks.keys():
    models[t] = ResNetForImageClassification(config)
    training_args = TrainingArguments(output_dir=f"./resnet_task_80%_{t}", num_train_epochs=10,
                                      save_total_limit=1, overwrite_output_dir=True, auto_find_batch_size=True, save_strategy='epoch')
    split = train_datasets[t].map(
        process_example).train_test_split(test_size=0.2)
    finetune_data.append(split['test'])
    split['train'].set_format(
        "pt", columns=['pixel_values'], output_all_columns=True)
    trainer = Trainer(
        model=models[t],
        args=training_args,
        train_dataset=split['train'],
    )
    trainer.train()

In [43]:
# Train Stacked model
stacking_ensemble = StackingEnsemble(list(models.values()))
stacked_dataset = concatenate_datasets(finetune_data).with_format("torch")
train_dataloader = torch.utils.data.DataLoader(stacked_dataset, batch_size=128)
num_epochs = 3
# Freeze all other params to train only the last classification layer
for name, param in stacking_ensemble.named_parameters():
    if name not in ['classifier.1.weight', 'classifier.1.bias']:
        param.requires_grad = False
optimizer = optim.Adam(stacking_ensemble.parameters(), lr=0.001) 
criterion = nn.CrossEntropyLoss() 
for epoch in range(num_epochs):
    print(f'EPOCH: {epoch}')
    for ex in train_dataloader:
        logits = stacking_ensemble(ex['pixel_values'])
        loss = criterion(logits, ex['label'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


EPOCH: 0


100%|██████████| 157/157 [27:33<00:00, 10.53s/it]


EPOCH: 1


100%|██████████| 157/157 [27:22<00:00, 10.46s/it]


EPOCH: 2


100%|██████████| 157/157 [28:03<00:00, 10.72s/it]


In [27]:
# Evaluate the StackingEnsemble model on the combined validation set
# Define validation dataloader
val_dataset = dataset['valid'].with_format("torch")
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=128)

device = 'cuda'
correct = 0
total = 0

stacking_ensemble.eval()
stacking_ensemble.to(device)
with torch.no_grad():
    # Loop through the data loader
    for ex in  val_dataloader:
        # Move images and labels to device
        images = ex['pixel_values'].to(device)
        labels = ex['label'].to(device)

        # Forward pass to get model predictions
        outputs = stacking_ensemble(images)

        # Get the predicted labels as the index of the maximum output value
        _, predicted = torch.max(outputs.data, 1)

        # Update total samples
        total += labels.size(0)

        # Update correct predictions
        correct += (predicted == labels).sum().item()

# Calculate accuracy
accuracy = (correct / total) * 100

print('Accuracy: {:.2f}%'.format(accuracy))

Accuracy: 0.50%


In [47]:
# save model
torch.save({
    'epoch': epoch,
    'model_state_dict': stacking_ensemble.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss}, 'stacking_ensemble.pt')
