<a href="https://colab.research.google.com/github/Benetell/marketing_memes/blob/main/resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install "bayesian-optimization==1.3.1"  --quiet
!pip install lightning  --quiet
!pip install scikit-learn torchvision numpy seaborn  --quiet
!pip install transformers --quiet

import pandas as pd
import os
import multiprocessing as mp
mp.set_start_method('spawn', force=True)
os.environ["OMP_NUM_THREADS"] = "1"  # Adjust based on your system's capability
import glob

# Then import other libraries and define your code
import random
from collections import defaultdict, Counter

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import confusion_matrix, average_precision_score
from sklearn.metrics import multilabel_confusion_matrix as ConfusionMatrix
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import get_cosine_schedule_with_warmup
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from torchvision.transforms import AutoAugment, AutoAugmentPolicy

import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from torchmetrics.classification import Accuracy, Precision, Recall, F1Score, MulticlassF1Score

from bayes_opt import BayesianOptimization
from tqdm import tqdm




In [None]:
class ResNetWrapper(pl.LightningModule):
    """Multilabel image classifier: frozen ResNet-50 backbone + dropout/linear head.

    Targets are multi-hot float vectors of length ``num_classes``. The loss is
    ``BCEWithLogitsLoss`` on the raw logits; a label is predicted positive when
    its sigmoid probability exceeds ``sigmoid_threshold`` (used for metrics and
    test predictions only — it does not influence the loss).

    Args:
        lr: Peak learning rate for AdamW. The schedule warms up linearly over
            the first 10 % of optimizer steps, then decays with a cosine.
        sigmoid_threshold: Probability cut-off for a positive prediction.
        dropout_rate: Dropout probability applied before the linear head.
        num_classes: Number of output labels (default 80, the value this class
            previously hard-coded).
        backbone_weights: torchvision ``ResNet50_Weights`` member or its name.
            ``"IMAGENET1K_V1"`` is what the deprecated ``pretrained=True`` flag
            loaded, so the default keeps the previous behaviour.
    """

    def __init__(self, lr, sigmoid_threshold=0.5, dropout_rate=0.5,
                 num_classes=80, backbone_weights="IMAGENET1K_V1"):
        super().__init__()
        self.lr = lr
        self.dropout_rate = dropout_rate
        self.sigmoid_threshold = sigmoid_threshold
        self.num_classes = num_classes
        # Filled by test_step; despite the names they hold test-set
        # predictions and targets (moved to the CPU).
        self.validation_step_y_hats = []
        self.validation_step_ys = []

        backbone = models.resnet50(weights=backbone_weights)
        # Every backbone child except the final fc layer.
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        # Frozen backbone: only the classifier head is trained.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        self.classifier = nn.Sequential(
            nn.Dropout(self.dropout_rate),  # hyperparameter
            nn.Linear(backbone.fc.in_features, out_features=num_classes),
        )

        self.loss_fn = nn.BCEWithLogitsLoss()

        self.train_acc = Accuracy(task="multilabel", num_labels=num_classes)
        self.val_acc = Accuracy(task="multilabel", num_labels=num_classes)
        self.train_precision = Precision(task="multilabel", num_labels=num_classes)
        self.val_precision = Precision(task="multilabel", num_labels=num_classes)
        self.train_recall = Recall(task="multilabel", num_labels=num_classes)
        self.val_recall = Recall(task="multilabel", num_labels=num_classes)
        self.train_f1 = F1Score(task="multilabel", num_labels=num_classes)
        self.val_f1 = F1Score(task="multilabel", num_labels=num_classes)

    def forward(self, x):
        """Return one logit per label for an image batch."""
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _predict(self, logits):
        """Threshold sigmoid probabilities into a 0/1 int tensor (torchmetrics-friendly)."""
        return (torch.sigmoid(logits) > self.sigmoid_threshold).int()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = self._predict(logits)
        targets = y.int()

        # Call the metrics (forward) rather than only .update(): forward both
        # accumulates state and computes the per-batch value that self.log
        # needs when it logs a Metric object on_step (the default here).
        self.train_acc(preds, targets)
        self.train_precision(preds, targets)
        self.train_recall(preds, targets)
        self.train_f1(preds, targets)

        self.log('train_loss', loss)
        self.log('train_acc', self.train_acc)
        self.log('train_precision', self.train_precision)
        self.log('train_recall', self.train_recall)
        self.log('train_f1', self.train_f1)

        return loss

    def on_train_epoch_end(self):
        """Reset training metrics so each epoch starts from empty state."""
        self.train_acc.reset()
        self.train_precision.reset()
        self.train_recall.reset()
        self.train_f1.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = self._predict(logits)
        targets = y.int()
        self.val_acc.update(preds, targets)
        self.val_precision.update(preds, targets)
        self.val_recall.update(preds, targets)
        self.val_f1.update(preds, targets)

        # sync_dist averages the epoch-level loss across DDP processes.
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, on_step=False, sync_dist=True)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        preds = torch.sigmoid(logits) > self.sigmoid_threshold

        # self.log only accepts scalar values, so the per-sample
        # (batch, num_classes) tensors are accumulated here instead.
        self.validation_step_y_hats.append(preds.cpu())
        self.validation_step_ys.append(y.cpu())
        return {'preds': preds, 'targets': y}

    def on_validation_epoch_end(self):
        """Log the epoch's validation metrics, then reset them."""
        self.log('val_acc', self.val_acc.compute(), on_epoch=True, prog_bar=True)
        self.log('val_precision', self.val_precision.compute(), on_epoch=True, prog_bar=True)
        self.log('val_recall', self.val_recall.compute(), on_epoch=True, prog_bar=True)
        self.log('val_f1', self.val_f1.compute(), on_epoch=True, prog_bar=True)
        self.val_acc.reset()
        self.val_precision.reset()
        self.val_recall.reset()
        self.val_f1.reset()

    def on_train_epoch_start(self):
        """Log the current learning rate once per training epoch."""
        # (Replaces on_epoch_start, a hook Lightning removed in v1.8.)
        current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log('lr', current_lr, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a per-step cosine schedule and 10 % linear warm-up."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        # estimated_stepping_batches accounts for max_epochs, the number of
        # devices and gradient accumulation, without calling
        # train_dataloader() a second time just to count batches.
        total_training_steps = int(self.trainer.estimated_stepping_batches)
        warmup_steps = int(0.1 * total_training_steps)

        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_training_steps,
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]



Best threshold: 0.3527
Best validation loss: 0.0591

## train test split

In [None]:

def target_to_oh(target):
    """Encode an integer class index as an 80-element float one-hot vector.

    The hot position is ``target - 1``.
    NOTE(review): ImageFolder class indices start at 0, so ``target == 0`` is
    written to the last column (index 79) through negative indexing, and would
    share that column with ``target == 80`` if such a class exists — confirm
    the offset is intended.
    """
    hot_index = target - 1
    vector = torch.zeros(80)
    vector[hot_index] = 1
    return vector

class MyDataModule(pl.LightningDataModule):
    """Lightning data module over datasets yielding ``(image, integer class)``.

    Each split is converted with ``target_to_oh`` into a list of
    ``(image, one-hot target)`` pairs, which is then wrapped in a DataLoader.
    The conversion is done at most once per split and cached, so repeated
    dataloader requests do not reload and re-transform every image.

    NOTE(review): the conversion iterates the whole dataset eagerly, so a
    random transform on the underlying dataset (e.g. AutoAugment) is sampled
    once per data module rather than once per epoch. The eager list of plain
    tensors is presumably kept because DataLoader workers use the 'spawn'
    start method (set at the top of the notebook), which typically cannot
    unpickle datasets/functions defined inside a notebook — confirm before
    making this lazy.

    Args:
        train_data: Training dataset.
        val_data: Validation dataset.
        test_dataset: Optional test dataset; required only for test_dataloader().
        batch_size: Batch size for every loader.
        num_workers: DataLoader worker processes for every loader.
    """

    def __init__(self, train_data, val_data, test_dataset=None, batch_size=32, num_workers=4):
        super().__init__()
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        # split name -> cached list of (image, one-hot target) pairs
        self._converted = {}

    def _converted_split(self, split, data):
        """Return *data* as a cached list of ``(image, one-hot target)`` pairs.

        Raises:
            ValueError: if *data* is None (the split was never provided).
        """
        if data is None:
            raise ValueError(f"MyDataModule was created without a {split} dataset")
        if split not in self._converted:
            self._converted[split] = [(x, target_to_oh(y)) for x, y in data]
        return self._converted[split]

    def _loader(self, split, data, shuffle):
        """Build a DataLoader over the converted *split* with this module's settings."""
        return DataLoader(
            self._converted_split(split, data),
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        return self._loader("train", self.train_data, shuffle=True)

    def val_dataloader(self):
        return self._loader("val", self.val_data, shuffle=False)

    def test_dataloader(self):
        return self._loader("test", self.test_data, shuffle=False)

In [None]:
def exclude_folders(dataset, excluded_folders):
    """
    Remove every sample whose parent folder name is in ``excluded_folders``.

    The dataset is modified in place (``samples``, ``targets`` and, if present,
    the ``imgs`` alias that ``ImageFolder`` keeps) and also returned.
    ``classes`` / ``class_to_idx`` are not touched, so the remaining targets
    keep their original indices and may have gaps.

    Args:
        dataset (ImageFolder): Dataset whose ``samples`` is a list of
            ``(path, target)`` pairs.
        excluded_folders (iterable of str): Folder basenames to drop. Any
            iterable works; it is read exactly once.

    Returns:
        tuple: ``(dataset, idx_to_label)`` where ``idx_to_label`` maps
        ``0..k-1`` to the ``k`` distinct remaining target indices, ascending.
    """
    # A set gives O(1) membership tests and, unlike re-scanning the argument
    # for every sample, also works when excluded_folders is a one-shot iterator.
    excluded = set(excluded_folders)
    filtered_samples = [
        sample
        for sample in dataset.samples
        if os.path.basename(os.path.dirname(sample[0])) not in excluded
    ]

    dataset.samples = filtered_samples
    dataset.targets = [sample[1] for sample in filtered_samples]
    # ImageFolder exposes the samples a second time as `imgs`; keep it in sync.
    if hasattr(dataset, "imgs"):
        dataset.imgs = filtered_samples

    idx_to_label = dict(enumerate(sorted(set(dataset.targets))))

    print(f"Excluded folders: {excluded_folders}")
    print(f"Remaining samples after exclusion: {len(dataset.samples)}")

    return dataset, idx_to_label


def create_train_test_split_proportional(dataset, test_ratio=0.1, seed=42, transform=None):
    """
    Create a per-class (stratified) train/test split of an ImageFolder dataset.

    For every target value, its samples are shuffled and the first
    ``int(n * test_ratio)`` go to the test split, the rest to train. Classes
    with fewer than ``1 / test_ratio`` samples therefore contribute nothing
    to the test split.

    Args:
        dataset (ImageFolder): The dataset to split (its ``samples`` and ``root``
            are read; it is not modified).
        test_ratio (float): Fraction of each class assigned to the test split,
            in ``[0, 1]``.
        seed (int): Seed for a private RNG, so the split is reproducible and the
            global ``random`` state is left untouched.
        transform: Transform given to both new ImageFolder datasets.

    Returns:
        train_dataset, test_dataset, idx_to_label: the two splits (fresh
        ImageFolder objects rooted at ``dataset.root``) and a dict mapping
        ``0..k-1`` to the distinct train target indices in ascending order.

    Raises:
        ValueError: if ``test_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")

    # Same permutations as random.seed(seed) + random.shuffle, without
    # reseeding the module-level generator for the rest of the notebook.
    rng = random.Random(seed)

    # Group samples by label
    label_to_samples = defaultdict(list)
    for sample in dataset.samples:
        label_to_samples[sample[1]].append(sample)

    train_samples = []
    test_samples = []
    for samples in label_to_samples.values():
        rng.shuffle(samples)
        num_test = int(len(samples) * test_ratio)
        test_samples.extend(samples[:num_test])
        train_samples.extend(samples[num_test:])

    def _folder_with(samples):
        """New ImageFolder over dataset.root restricted to *samples*."""
        folder = ImageFolder(dataset.root, transform=transform)
        folder.samples = samples
        folder.imgs = samples  # ImageFolder's alias of `samples`; keep it in sync
        folder.targets = [sample[1] for sample in samples]
        return folder

    train_dataset = _folder_with(train_samples)
    test_dataset = _folder_with(test_samples)

    # Rebuild idx_to_label after split
    idx_to_label = dict(enumerate(sorted(set(train_dataset.targets))))

    return train_dataset, test_dataset, idx_to_label



In [None]:
# Training-time preprocessing: resize, random AutoAugment (ImageNet policy),
# tensor conversion and ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Evaluation-time preprocessing: identical except for the random augmentation,
# so held-out images are scored deterministically.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Apply exclusions
excluded_folders = ["AXA images"]  # Replace with the names of folders you want to exclude
dataset = ImageFolder('/kaggle/input/25dataset/125images', transform=transform)
print(f"Original dataset size: {len(dataset)}")

# Exclude specified folders and get the updated dataset and label mapping
dataset, idx_to_label = exclude_folders(dataset, excluded_folders)

# Create the train-test split and update label mapping
train_dataset, test_dataset, idx_to_label = create_train_test_split_proportional(dataset, test_ratio=0.08, seed=42, transform=transform)
# ImageFolder applies `.transform` in __getitem__, so swapping it makes the test
# split augmentation-free. (KFold validation folds are Subsets of train_dataset
# and therefore still see the augmenting transform.)
test_dataset.transform = eval_transform

unique_classes = {sample[1] for sample in dataset.samples}
print()
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}", '\n')
print("Unique class indices in dataset:", unique_classes, '\n')
print("Number of unique classes:", len(unique_classes), '\n')
print("Train label distribution:", Counter([sample[1] for sample in train_dataset.samples]), '\n')
print("Test label distribution:", Counter([sample[1] for sample in test_dataset.samples]), '\n')


## training

In [None]:


# KFold Cross-Validation
import re

batch_size = 23
lr = 0.008274600824983372
threshold = 0.3527
kf = KFold(n_splits=5, shuffle=True, random_state=42)  # 5-fold cross-validation
val_losses = []
best_model_path = None
best_fold_score = float("-inf")


def _fold_best_score(callback):
    """Return the best monitored val_f1 of a fold's ModelCheckpoint, or None.

    With strategy="ddp_notebook" training runs in worker processes; Lightning's
    multiprocessing launcher copies ``best_model_path`` back to this process
    but (as of Lightning 2.x) not ``best_model_score``. In that case the score
    is parsed from the checkpoint file name, which follows the
    ``"{epoch}-{val_f1:.4f}"`` pattern configured in the loop.
    """
    if callback.best_model_score is not None:
        return float(callback.best_model_score)
    match = re.search(r"val_f1=(\d+\.\d+)", callback.best_model_path or "")
    return float(match.group(1)) if match else None


for fold, (train_idx, val_idx) in tqdm(enumerate(kf.split(np.arange(len(train_dataset)))), total=kf.get_n_splits(), desc="K-Fold Cross-Validation"):
    print(f"Fold {fold + 1}/{kf.n_splits}")

    # Split the dataset into train and validation subsets using indices
    train_subset = Subset(train_dataset, train_idx)
    val_subset = Subset(train_dataset, val_idx)

    # batch_size must be passed by keyword: MyDataModule's third positional
    # parameter is test_dataset.
    data_module = MyDataModule(train_subset, val_subset, batch_size=batch_size, num_workers=1)

    # Setup the model for multilabel classification
    model = ResNetWrapper(lr=lr, sigmoid_threshold=threshold, dropout_rate=0.5)

    logger = TensorBoardLogger("lightning_logs", name=f"multilabel_training_fold_{fold}")

    # Keep only the best checkpoint of this fold (by validation F1)
    checkpoint_callback = ModelCheckpoint(
        monitor="val_f1",
        mode="max",
        save_top_k=1,
        filename="{epoch}-{val_f1:.4f}"
    )

    # Stop after 3 epochs without val_f1 improvement
    early_stopping = EarlyStopping(
        monitor="val_f1",
        patience=3,
        mode="max"
    )

    trainer = Trainer(
        logger=logger,
        max_epochs=30,
        devices='auto',  # Adjust based on your hardware
        accelerator="gpu",  # Use "gpu" or "tpu" based on availability
        strategy="ddp_notebook",
        callbacks=[early_stopping, checkpoint_callback],
    )

    trainer.fit(model, datamodule=data_module)

    # Validation loss of the fold's last epoch
    val_losses.append(trainer.callback_metrics["val_loss"].item())

    # Keep the best checkpoint over *all* folds. Folds with an unknown score
    # rank lowest; ">=" lets a later fold win ties, so if no score is ever
    # known the last fold's checkpoint is used (the previous behaviour).
    fold_score = _fold_best_score(checkpoint_callback)
    if fold_score is None:
        fold_score = float("-inf")
    if best_model_path is None or fold_score >= best_fold_score:
        best_fold_score = fold_score
        best_model_path = checkpoint_callback.best_model_path

print(f"Mean validation loss over {len(val_losses)} folds: {np.mean(val_losses):.4f}")
print("Best model saved at:", best_model_path)


In [None]:
import matplotlib.pyplot as plt

# Show the first training image. The tensor is ImageNet-normalized by the
# dataset transform, so undo that and clamp to [0, 1] before plotting;
# otherwise imshow clips the values and the colours are wrong.
image, label = train_dataset[0]
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
display_image = (image * _std + _mean).clamp(0, 1)
plt.imshow(display_image.permute(1, 2, 0))  # reorder channels from (C, H, W) to (H, W, C)
plt.title(f"Label: {label} ({train_dataset.classes[label]})")
plt.show()


In [None]:
!sudo apt install net-tools

In [None]:
import os

# Use a fresh rendezvous port so a new DDP run does not collide with a previous one.
os.environ.update({'MASTER_PORT': '29501'})


## test

In [None]:
!pip install pretty-confusion-matrix
import matplotlib.pyplot as plt
import numpy
from sklearn import metrics

In [None]:
from sklearn.metrics import precision_recall_fscore_support, multilabel_confusion_matrix

# Reload the best cross-validation checkpoint with the tuned sigmoid threshold.
best_model = ResNetWrapper.load_from_checkpoint(
    best_model_path,
    lr=lr,
    sigmoid_threshold=0.3527,
    dropout_rate=0.5,
)
# Accumulators that evaluate_model fills: confusion matrices and per-label F1.
mxs, f1s = [], {}
def evaluate_model(model, test_loader, sigmoid_threshold=0.5):
    """
    Evaluate a multilabel model on *test_loader* and report per-label F1.

    Side effects, kept for the notebook cells that read them: appends the
    per-label 2x2 confusion matrices (sklearn ``multilabel_confusion_matrix``)
    to the global ``mxs`` list and stores each label's F1 in the global
    ``f1s`` dict (label index -> F1).

    Args:
        model: Module mapping an image batch to one logit per label.
        test_loader: Iterable of ``(images, multi-hot targets)`` batches.
        sigmoid_threshold: Probability above which a label is predicted.

    Returns:
        dict with per-label ``precision``, ``recall``, ``f1`` and ``support``
        arrays plus the unweighted ``mean_f1`` over labels.

    Raises:
        ValueError: if *test_loader* yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for x, y in test_loader:
            logits = model(x.to(device))
            # Sigmoid + threshold; int 0/1 arrays for sklearn.
            all_preds.append((torch.sigmoid(logits) > sigmoid_threshold).cpu().numpy().astype(int))
            all_targets.append(y.cpu().numpy().astype(int))

    if not all_preds:
        raise ValueError("test_loader produced no batches; nothing to evaluate")

    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    mxs.append(multilabel_confusion_matrix(all_targets, all_preds))

    # zero_division=0 gives the same 0 sklearn assigns by default for labels
    # that are never predicted / never present, without the warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_targets, all_preds, average=None, zero_division=0
    )
    label_support = all_targets.sum(axis=0)
    print("Number of true samples for each label:", label_support)

    print("\nF1 Score for each label:")
    for i, label_f1 in enumerate(f1):
        f1s[i] = label_f1
        print(f"  label {i}: {label_f1:.4f}")

    mean_f1 = float(np.mean(f1))
    print(f"Mean F1 score: {mean_f1:.4f}")

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "support": label_support,
        "mean_f1": mean_f1,
    }

# Test the evaluation. Only the test split is read, so the train/val slots of
# the data module are simply filled with train_dataset.
test_data_module = MyDataModule(
    train_dataset,
    train_dataset,
    test_dataset,
    batch_size,
    num_workers=2,
)
test_loader = test_data_module.test_dataloader()
evaluate_model(best_model, test_loader, sigmoid_threshold=0.3527)


Mean F1 score: 0.4040

In [None]:
# (label index, F1) pairs, weakest labels first.
sorted(f1s.items(), key=lambda pair: pair[1])

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# One 2x2 confusion matrix per model output column, as returned by sklearn's
# multilabel_confusion_matrix ([[TN, FP], [FN, TP]]).
mxs_0 = mxs[0]

# Grid with 8 columns and as many rows as needed (10 rows for 80 labels).
n_cols = 8
n_rows = max(1, -(-len(mxs_0) // n_cols))

# Name each output column after the class folder(s) whose targets
# target_to_oh encodes into it. idx_to_label holds class *indices*, not names,
# and target_to_oh shifts indices, so derive the names from the encoding.
present_targets = set(train_dataset.targets)
column_names = defaultdict(list)
for class_name, class_idx in train_dataset.class_to_idx.items():
    if class_idx in present_targets:
        column_names[int(target_to_oh(class_idx).argmax())].append(class_name)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(24, 16), squeeze=False)
axes = axes.ravel()  # flatten so each matrix gets one axis by index

for i, mx in enumerate(mxs_0):
    ax = axes[i]
    # Default colour scaling: the previous vmin=0.5/vmax=1 saturated every
    # non-zero count, so the colours carried no information.
    sns.heatmap(mx, annot=True, fmt='d', cbar=False, ax=ax, square=True)

    label_name = " / ".join(column_names.get(i, [])) or f"Label {i}"
    ax.set_title(f'{label_name} (Index {i})', fontsize=8)
    ax.set_xlabel('Predicted', fontsize=8)
    ax.set_ylabel('Actual', fontsize=8)

# Hide any unused subplots (works even when there are no matrices).
for ax in axes[len(mxs_0):]:
    ax.axis('off')

plt.tight_layout()
plt.show()


## Bayesian optimization

In [None]:
# Define the objective function for Bayesian Optimization
batch_size = 23


def objective(threshold, lr):
    """5-fold cross-validation objective for BayesianOptimization.

    Trains one ResNetWrapper per fold with sigmoid ``threshold`` and learning
    rate ``lr`` and returns the negated mean of the folds' final ``val_loss``
    (BayesianOptimization maximizes, so a lower loss scores higher). A fold
    that logged no ``val_loss`` counts as ``inf``.

    NOTE(review): ``val_loss`` is BCEWithLogitsLoss on the raw logits and
    ResNetWrapper only uses ``sigmoid_threshold`` to turn probabilities into
    predictions, so this objective cannot distinguish thresholds; optimize
    ``val_f1`` instead if the threshold is meant to be tuned.
    """
    kf = KFold(n_splits=5, shuffle=True, random_state=42)  # 5-fold cross-validation
    val_losses = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(train_dataset)))):
        train_subset = Subset(train_dataset, train_idx)
        val_subset = Subset(train_dataset, val_idx)

        # batch_size must be passed by keyword: MyDataModule's third
        # positional parameter is test_dataset.
        data_module = MyDataModule(train_subset, val_subset, batch_size=batch_size, num_workers=2)
        model = ResNetWrapper(lr=lr, sigmoid_threshold=threshold, dropout_rate=0.5)

        # Separate log directory so tuning runs do not mix their checkpoints
        # with the cross-validation runs under "multilabel_training_fold_*".
        logger = TensorBoardLogger("lightning_logs", name=f"bayes_opt_fold_{fold}")

        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",
            mode="min",  # Minimize the val_loss
            save_top_k=1,  # Save only the best model
            filename="{epoch}-{val_loss:.2f}"
        )

        early_stopping = EarlyStopping(
            monitor="val_loss",
            patience=3,
            mode="min"
        )

        trainer = Trainer(
            logger=logger,
            max_epochs=12,
            devices="auto",  # Adjust based on your hardware
            accelerator="gpu",  # Use "gpu" or "tpu" based on availability
            strategy="ddp_notebook",
            callbacks=[early_stopping, checkpoint_callback]
        )

        trainer.fit(model, datamodule=data_module)

        fold_loss = trainer.callback_metrics.get("val_loss")
        val_losses.append(fold_loss.item() if fold_loss is not None else float("inf"))

    # Average over all folds (previously only the last fold's loss was used).
    return -float(np.mean(val_losses))

# Search space: sigmoid threshold and learning rate, each as (low, high).
pbounds = dict(
    threshold=(0.1, 0.7),
    lr=(1e-6, 1e-2),
)

# 5 random probes first, then 5 steps guided by the optimizer's surrogate
# model; random_state makes the sequence reproducible.
optimizer = BayesianOptimization(
    f=objective, pbounds=pbounds, random_state=42, verbose=2,
)
optimizer.maximize(init_points=5, n_iter=5)

In [None]:
# Print the best parameters found by the Bayesian optimization run.
# Copy so that editing the dict does not change optimizer.max itself.
best_params = dict(optimizer.max['params'])
# batch_size is only present if it was added to the search space (indexing it
# unconditionally raised KeyError). The optimizer proposes floats, so round it.
if 'batch_size' in best_params:
    best_params['batch_size'] = int(round(best_params['batch_size']))
print("Best parameters: ", best_params)

In [None]:
# The objective returns -val_loss, so negate the best target to get the loss back.
best_run = optimizer.max
print(f"Best threshold: {best_run['params']['threshold']:.4f}")
print(f"Best validation loss: {-best_run['target']:.4f}")

## Download

In [None]:
import os
import subprocess
from IPython.display import FileLink, display

def download_file(path, download_file_name, work_dir='/kaggle/working/'):
    """Zip *path* recursively into ``<work_dir>/<download_file_name>.zip`` and show a link.

    Changes the process working directory to *work_dir* (kept from the
    original behaviour so the relative FileLink, and later relative paths in
    the notebook, resolve there). If zipping fails, prints the reason and
    returns None without displaying a link.
    """
    os.chdir(work_dir)
    zip_name = os.path.join(work_dir, f"{download_file_name}.zip")
    # Argument list with shell=False: paths containing spaces or shell
    # metacharacters are passed to zip verbatim instead of being re-parsed.
    command = ["zip", "-r", zip_name, path]
    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except FileNotFoundError:
        print("Unable to run zip command!")
        print("The `zip` executable was not found on PATH.")
        return
    if result.returncode != 0:
        print("Unable to run zip command!")
        print(result.stderr)
        return
    display(FileLink(f'{download_file_name}.zip'))
download_file('/kaggle/working/lightning_logs/training_proper_test', 'out')

In [None]:
!zip -r file.zip /kaggle/working/lightning_logs/training_proper_test
from IPython.display import FileLink
FileLink(path='file.zip')  # download link for the archive created above

## Testing on the labeled logo dataset

In [None]:
!ls lightning_logs/training_proper_test/version_24/checkpoints/

In [None]:
from torchmetrics.classification import MulticlassConfusionMatrix

# Confusion matrix of the top-1 predictions held in the global lists
# all_preds / all_labels (integer class indices filled by the checkpoint
# evaluation cell; run that cell first).
if len(all_preds) == 0:
    raise ValueError("all_preds is empty; run the checkpoint evaluation cell first")
preds_tensor = torch.as_tensor(np.asarray(all_preds), dtype=torch.long)
labels_tensor = torch.as_tensor(np.asarray(all_labels), dtype=torch.long)
# Size the matrix from the data: the previous hard-coded 5 classes is far
# smaller than this notebook's label space.
num_cm_classes = int(torch.max(preds_tensor.max(), labels_tensor.max()).item()) + 1
metric = MulticlassConfusionMatrix(num_classes=num_cm_classes)
# torchmetrics' update signature is (preds, target); the arguments were swapped before.
metric.update(preds_tensor, labels_tensor)
fig_, ax_ = metric.plot()

In [None]:
import os
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from sklearn.metrics import f1_score
from glob import glob

# Evaluate every checkpoint under one fold's log directory on the held-out
# test split (top-1 accuracy and macro F1), then report the best by F1.
batch_size = 23
lr = 0.008274600824983372
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# test_dataset yields (image, raw ImageFolder class index) pairs.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# Find all checkpoint files recursively across all versions
checkpoint_directory = '/kaggle/working/lightning_logs/multilabel_training_fold_2/'
checkpoint_files = glob(os.path.join(checkpoint_directory, '**', '*.ckpt'), recursive=True)
if not checkpoint_files:
    raise FileNotFoundError(f"No .ckpt files found under {checkpoint_directory}")

# checkpoint path -> {'f1_score': ..., 'accuracy': ...}
checkpoint_performance = {}

for checkpoint in checkpoint_files:
    # Load *this* checkpoint (previously `best_model` was re-evaluated every time).
    model = ResNetWrapper.load_from_checkpoint(checkpoint, lr=lr, sigmoid_threshold=0.3527, dropout_rate=0.5)
    model.eval()
    model.to(device)

    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            # Map raw class indices to model output columns with the same
            # encoding used for training targets, so they are comparable.
            columns = torch.stack([target_to_oh(int(t)) for t in labels]).argmax(dim=1).to(device)

            outputs = model(inputs)
            # argmax of the logits equals argmax of their softmax.
            predicted = outputs.argmax(dim=1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(columns.cpu().numpy())

            correct += (predicted == columns).sum().item()
            total += columns.size(0)

    f1 = f1_score(all_labels, all_preds, average='macro')
    accuracy = 100 * correct / total if total else float('nan')

    checkpoint_performance[checkpoint] = {'f1_score': f1, 'accuracy': accuracy}

    print(f"Checkpoint: {checkpoint}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Accuracy: {accuracy:.2f}%\n")

# Find the best checkpoint based on F1 score
best_checkpoint = max(checkpoint_performance, key=lambda x: checkpoint_performance[x]['f1_score'])
best_f1_score = checkpoint_performance[best_checkpoint]['f1_score']
best_accuracy = checkpoint_performance[best_checkpoint]['accuracy']

print("\nBest Checkpoint:")
print(f"Path: {best_checkpoint}")
print(f"F1 Score: {best_f1_score:.4f}")
print(f"Accuracy: {best_accuracy:.2f}%")


In [None]:
import matplotlib.pyplot as plt
import torch

# Undo a per-channel Normalize so images can be displayed.
def unnormalize(img_tensor, mean, std):
    """Return a copy of *img_tensor* with ``x * std + mean`` applied per channel.

    Channels (first dimension) are paired positionally with the entries of
    ``mean`` and ``std``; pairing stops at the shortest of the three. The
    input tensor itself is not modified.
    """
    restored = img_tensor.clone()
    per_channel_stats = zip(mean, std)
    for channel, (channel_mean, channel_std) in zip(restored, per_channel_stats):
        channel.mul_(channel_std).add_(channel_mean)
    return restored

# Mean and std of the Normalize transform, needed to undo it for display
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def plot_image_grid(dataset, num_images=16):
    """Show the first *num_images* ``(image, label)`` samples of *dataset* in a grid.

    Images are assumed to be normalized ``(C, H, W)`` tensors (as produced by
    this notebook's transforms). They are unnormalized with the module-level
    ``mean``/``std`` and clamped to [0, 1] for imshow. The grid is as close
    to square as possible (4x4 for the default 16), and *num_images* is capped
    at ``len(dataset)``.
    """
    num_images = min(num_images, len(dataset))
    if num_images <= 0:
        print("Nothing to plot: the dataset is empty or num_images <= 0.")
        return

    n_cols = int(np.ceil(np.sqrt(num_images)))
    n_rows = -(-num_images // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i in range(num_images):
        img, label = dataset[i]
        img = unnormalize(img, mean, std).clamp(0, 1)
        img = img.permute(1, 2, 0).numpy()  # Convert from CxHxW to HxWxC

        axes[i].imshow(img)
        axes[i].set_title(f"Label: {label}")
        axes[i].axis("off")

    # Hide grid cells left empty when num_images is not a full rectangle.
    for ax in axes[num_images:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Visualize some images from the training dataset (`test_set` was undefined)
plot_image_grid(train_dataset, num_images=16)


In [None]:
import os
import glob


# Evaluate an earlier 100-class checkpoint on the labeled brand-logo test set.

# Minimum softmax probability for a prediction to count; less confident
# predictions are treated as "no prediction" and scored as wrong.
threshold = 0.8
batch_size = 23
lr = 0.008274600824983372

results = []
checkpoint = "/kaggle/input/checkpoints/lightning_logs/the training/version_4/checkpoints/epoch=8-val_loss=1.39.ckpt"

# Deterministic evaluation preprocessing: same resize/normalization as in
# training, without random augmentation.
eval_logo_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the test dataset.
# NOTE(review): assumes the checkpoint's output columns follow this
# ImageFolder's (sorted folder) class order — confirm.
test_dataset = ImageFolder(root='/kaggle/input/brand-logos/test', transform=eval_logo_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# The extra keyword arguments are forwarded to ResNetWrapper's constructor,
# which therefore has to accept num_classes and backbone_weights.
model = ResNetWrapper.load_from_checkpoint(checkpoint, lr=lr, num_classes=100,
                                            backbone_weights=models.ResNet50_Weights.IMAGENET1K_V2)
print(isinstance(model, pl.LightningModule))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Collect logits and labels directly: Trainer.test() only returns the logged
# scalar metrics, so it cannot provide per-sample predictions.
logit_batches, label_batches = [], []
with torch.no_grad():
    for images, batch_labels in test_loader:
        logit_batches.append(model(images.to(device)).cpu())
        label_batches.append(batch_labels)
preds = torch.cat(logit_batches)
labels = torch.cat(label_batches)

# Top-1 prediction, replaced by -1 (never a valid label, so always wrong)
# when its probability does not exceed the threshold.
probabilities = F.softmax(preds, dim=1)
confidence, top_class = probabilities.max(dim=1)
predicted = torch.where(confidence > threshold, top_class, torch.full_like(top_class, -1))

accuracy = (predicted == labels).sum().item() / len(labels)
results.append({'checkpoint': checkpoint, 'accuracy': accuracy})

# Output all results
for result in results:
    print(f"Checkpoint: {result['checkpoint']} - Accuracy: {result['accuracy']:.4f}")


In [None]:
# Macro-averaged F1 (every class weighted equally) over the predictions and
# labels collected by an earlier evaluation cell.
overall_f1_macro = f1_score(y_true=all_labels, y_pred=all_preds, average='macro')
print(overall_f1_macro)

### Saving best models

In [None]:

!ls /kaggle/working/checkpoints/lightning_logs/
!ls /kaggle/working/checkpoints/lightning_logs/version_6/checkpoints
!zip -r nyertes.zip /kaggle/working/checkpoints/lightning_logs/version_13/checkpoints/epoch=14-step=2190.ckpt


In [None]:
!zip -r version_5.zip /kaggle/working/checkpoints/lightning_logs/

In [None]:
!pip install -U ipywidgets
!pip install ipywidgets
!jupyter nbextension enable --py widgetsnbextension


In [None]:
!rm -rf /kaggle/working/*

## Test on unlabeled marketing memes

In [None]:
import os
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
# Class names indexed by model output column: classify_images maps a
# predicted index i to brand_names[i]. 100 entries, lower-case.
# NOTE(review): the order looks like the sorted class-folder order of the
# 100-class training set (upper-case folder names such as "AXA" sorting
# first, lower-case ones like "adidas"/"ebay" last) — it must match the
# class order the checkpoint was trained with; verify.
brand_names = [
    "3m", "axa", "accenture", "adobe", "airbnb", "allianz", "amazon", "americanexpress",
    "apple", "audi", "bmw", "budweiser", "burberry", "canon", "cartier", "caterpillar",
    "chanel", "cisco", "citi bank", "cocacola", "colgate", "corona", "dhl", "danone",
    "dior", "disney", "facebook", "fedex", "ferrari", "ford", "ge", "gillette",
    "goldmansachs", "google", "gucci", "hm", "hp", "hsbc", "heineken", "hennessy",
    "hermès", "hewlettpackardenterprise", "honda", "huawei", "hyundai", "ibm", "ikea",
    "instagram", "intel", "jpmorgan", "jackdaniels", "johnsonjohnson", "kfc", "kelloggs",
    "kia", "lego", "loréalparis", "linkedin", "louisvuitton", "mastercard", "mcdonalds",
    "mercedesbenz", "microsoft", "morganstanley", "nescafé", "nespresso", "nestlé",
    "netflix", "nike", "nintendo", "nissan", "oracle", "pampers", "panasonic", "paypal",
    "pepsi", "philips", "porsche", "prada", "redbull", "sap", "salesforce", "samsung",
    "santander", "sephora", "siemens", "sony", "spotify", "starbucks", "tesla",
    "tiffanyco", "toyota", "ups", "visa", "volkswagen", "xiaomi", "youtube", "zara",
    "adidas", "ebay"
]

class UnlabeledImageDataset(Dataset):
    """Image files from a single directory, returned as ``(image, path)`` pairs.

    Args:
        image_dir: Directory to scan (not recursive).
        transform: Optional callable applied to each RGB PIL image.
        extensions: Case-insensitive file-name suffixes to include. The default
            matches the previously accepted PNG/JPEG files.
    """

    def __init__(self, image_dir, transform=None, extensions=('.png', '.jpg', '.jpeg')):
        self.image_dir = image_dir
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so indices (and therefore batch order) are reproducible;
        # os.listdir returns entries in arbitrary order.
        self.image_paths = [
            os.path.join(image_dir, name)
            for name in sorted(os.listdir(image_dir))
            if name.lower().endswith(suffixes)
        ]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager closes the file handle; convert() returns a
        # loaded copy that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, img_path


# Restore the 100-class checkpoint. The extra keyword arguments are forwarded
# to ResNetWrapper's constructor.
num_classes = 100  # Adjust based on your actual number of classes
model = ResNetWrapper.load_from_checkpoint(
    checkpoint,
    lr=lr,
    num_classes=num_classes,
    backbone_weights=models.ResNet50_Weights.IMAGENET1K_V2,
)

# Run on the GPU when one is available, otherwise on the CPU, in inference mode.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

# Inference preprocessing: fixed resize plus ImageNet normalization, no augmentation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet50 typically takes 224x224 input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Define the prediction function
def classify_images(model, image_folder, transform, device,
                    confidence_threshold=0.9, batch_size=32, class_names=None):
    """Group the confidently classified images of a folder by predicted brand.

    Every image file found by ``UnlabeledImageDataset`` in *image_folder* is
    preprocessed with *transform* and run through *model* in batches. An image
    is kept only if the softmax probability of its top class is at least
    *confidence_threshold*.

    Args:
        model: Classifier whose forward pass returns one logit per class,
            presumably shaped (batch, num_classes); confirm for the wrapper
            in use.
        image_folder: Directory containing the images to classify.
        transform: Preprocessing applied to each loaded RGB image.
        device: Device the image batches are moved to; should match *model*.
        confidence_threshold: Minimum top-class probability for an image to
            be kept (default 0.9).
        batch_size: Number of images per forward pass (default 32).
        class_names: Sequence mapping class index to brand name. Defaults to
            the module-level ``brand_names`` list.

    Returns:
        dict mapping lower-cased brand name to the list of file names (not full
        paths) predicted as that brand. Brands with no confident prediction
        do not appear as keys.
    """
    if class_names is None:
        class_names = brand_names

    dataset = UnlabeledImageDataset(image_folder, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    results = {}

    # Guard against a caller that left the model in training mode, where
    # dropout would make the predictions stochastic.
    model.eval()
    with torch.no_grad():
        for images, img_paths in dataloader:
            logits = model(images.to(device))
            probs = torch.softmax(logits, dim=1)  # logits -> class probabilities
            confidences, preds = probs.max(dim=1)  # top probability and its index

            # One host transfer per batch instead of an .item() sync per image.
            for confidence, pred, img_path in zip(confidences.tolist(), preds.tolist(), img_paths):
                if confidence >= confidence_threshold:
                    brand = class_names[pred].lower()  # map index to brand name
                    # Store the file name only, not the full path.
                    results.setdefault(brand, []).append(os.path.basename(img_path))

    return results

# Define the folder containing images to classify
image_folder = "/kaggle/input/5memes-for-top-100-most-valuable-brand-2023"

# Classify images once and store the brand -> file-name results. (A second,
# identical call used to follow here; it repeated the whole inference pass
# only to overwrite dict1 with the same result.)
dict1 = classify_images(model, image_folder, transform, device)
# Hard-coded brand -> meme file-name mapping that is merged with the model's
# predictions (dict1). Its origin is not shown here; presumably a manual
# labelling or an earlier run. TODO confirm.
# NOTE(review): keys use display spellings ("mercedes-benz", "l’oréal paris")
# that differ from the lower-cased brand_names keys produced by
# classify_images ("mercedesbenz", "loréalparis"). Some values are also not
# file names (e.g. 'hp enterprise' under 'hewlett packard enterprise').
# The literal is wrapped only between tokens: no string spans a line break.
dict2 = {'apple': ['Xiaomi memes_9daf00f7-8607-4ff6-b8e1-1dd5c755f43a.jpeg', 'Apple memes_iphone-chatgpt.jpg', 'Apple memes_Image_2.jpg', 'Huawei memes_Image_3.jpg'], 'microsoft': ['Microsoft memes_Image_5.jpg', 'Microsoft memes_Image_1.jpeg', 'Microsoft memes_Image_4.jpeg', 'HP memes_190d521ac62eaf40a74e50ab39635231.jpg'], 'amazon': ['Amazon memes_Image_1.jpg', 'Amazon memes_Image_2.jpg', 'Amazon memes_Image_4.jpeg', 'Amazon memes_Image_5.jpg', 'Amazon memes_Image_3.jpg'], 'google': ['Google memes_Image_2.jpeg', 'Google memes_Image_1.jpg', 'Xiaomi memes_Image_2.jpeg', 'Google memes_Image_4.jpg', 'Google memes_Image_3.jpg'], 'samsung': ['Samsung memes_Image_1.jpeg', 'Samsung memes_Image_2.jpg', 'Samsung memes_Image_3.jpg', 'Samsung memes_Image_4.jpg', 'Samsung memes_Image_5.jpeg'], 'toyota': ['Toyota memes_Image_1.jpg', 'Toyota memes_Image_2.jpg', 'Toyota memes_Image_4.jpg'], 'mercedes-benz': ['MercedesBenz memes_Image_5.jpg', 'MercedesBenz memes_Image_3.jpg'], 'coca-cola': ['CocaCola memes_Image_2.jpeg', 'Pepsi memes_Image_3.jpg', 'CocaCola memes_Image_1.jpg'], 'nike': [], 'bmw': ['BMW memes_ayNdM3q_460s.jpg', 'BMW memes_Image_4.jpg', 'Audi memes_Image_5.jpg', 'MercedesBenz memes_Image_1.jpg'], "mcdonald's": ['McDonalds memes_Image_2.jpg', 'McDonalds memes_Image_1.jpg', 'McDonalds memes_Image_2 (copy 1).jpg'], 'tesla': ['Tesla memes_Image_2.jpg', 'Tesla memes_teslamemes-191110030713-thumbnail.webp', 'Caterpillar memes_image-1.webp', 'Tesla memes_Image_5.jpg'], 'disney': ['Disney memes_Image_1.png', 'Disney memes_Image_2.jpg', 'Disney memes_relatable-disney-memes.png'], 'louis vuitton': ['LouisVuitton memes_Image_1.jpg', 'LouisVuitton memes_tgeycyn3z4k91.jpg', 'LouisVuitton memes_0a2163e267c3e9c175e7c08bf8253318.jpg', 'LouisVuitton memes_Image_2.jpg'], 'cisco': ['Cisco memes_d77c72918e7440d37fc3eabc307ad8ac3a8025383ea34cdb6e1df27c047628f6_1.jpg', 'Cisco memes_Image_3.jpg', 'Cisco memes_Image_1.jpg', 'Cisco memes_fc0j9uamxd371.png'], 'instagram': ['Instagram memes_Image_4.jpg',
    'Instagram memes_ig.jpg', 'Heineken memes_Image_1.jpg', 'Instagram memes_Image_5.jpg', 'Airbnb memes_Image_1.jpg', 'Instagram memes_person-women-on-instagram-be-like-got-new-shoes.jpeg', 'Instagram memes_Image_2.jpg', 'HewlettPackardEnterprise memes_EPo6d7mUYAAHsU0.jpg'], 'adobe': ['Adobe memes_Image_3.jpeg', 'Adobe memes_Image_1.jpeg', 'Adobe memes_Image_2.jpeg'], 'ibm': ['IBM memes_Image_2.jpg', 'IBM memes_Image_1.jpg', 'IBM memes_Image_3.jpg'], 'oracle': ['Oracle memes_EXT38P1WkAAjs_5.jpg', 'Oracle memes_Image_5.jpg', 'Oracle memes_programmerhumor-io-databases-memes-backend-memes-ce04a4d894b6035-608x613.webp'], 'sap': ['SAP memes_Image_1.jpeg', 'SAP memes_Image_3.jpg', 'SAP memes_Image_5.png', 'SAP memes_Image_2.jpeg', 'SAP memes_Image_4.jpg'], 'facebook': ['Instagram memes_Image_4.jpg', 'Heineken memes_Image_1.jpg', 'Caterpillar memes_Bj9v-r7IUAApml9.png', 'Facebook memes_Image_1.jpg', 'Facebook memes_Screenshot2021-11-01at16.29.20.jpg', 'TiffanyCo memes_98facae07556bfaab7cb6672371ff06c.jpg', 'Facebook memes_f.jpg', 'Facebook memes_Image_4.jpg', 'HewlettPackardEnterprise memes_EPo6d7mUYAAHsU0.jpg'], 'chanel': ['Chanel memes_images (1).jpeg', 'Chanel memes_Screen-Shot-2021-04-29-at-8.webp', 'Chanel memes_Image_2.jpg', 'Chanel memes_images.jpeg'], 'hermes': ['Herms memes_434053634_18427064311043375_1898595488137274113_n.jpg', 'Herms memes_Fd_gt0bUUAAeCRh.jpg'], 'intel': ['Intel memes_Image_4 (copy 1).jpeg', 'Intel memes_Image_2.jpg', 'Intel memes_Image_4.jpeg', 'Intel memes_Image_1.jpg'], 'youtube': ['YouTube memes_Image_2.jpg', 'YouTube memes_Image_3.jpg'], 'j.p. morgan':
    ['JPMorgan memes_xq0z4545oqe41.webp', 'MorganStanley memes_Image_1.jpg'], 'honda': ['Honda memes_Image_4.jpg', 'Honda memes_Image_5.jpg'], 'american express': ['AmericanExpress memes_Image_3.jpg', 'AmericanExpress memes_Image_4.jpg', 'AmericanExpress memes_Image_1.jpg'], 'ikea': ['IKEA memes_Image_3.jpg', 'IKEA memes_Image_5.png', 'IKEA memes_Image_2.jpg', 'IKEA memes_Image_4.jpg'], 'accenture': ['Accenture memes_Image_4.jpg', 'Accenture memes_Image_1.png', 'Accenture memes_Image_2.jpg', 'Accenture memes_Image_5.jpg'], 'allianz': ['Allianz memes_amLOMzj_460s.jpg', 'Allianz memes_Image_2.jpg', 'Allianz memes_Image_1.jpg', 'Allianz memes_images.jpeg'], 'hyundai': ['Hyundai memes_Image_3.jpeg', 'Hyundai memes_Image_4.jpg'], 'ups': ['Cisco memes_d77c72918e7440d37fc3eabc307ad8ac3a8025383ea34cdb6e1df27c047628f6_1.jpg', 'UPS memes_Image_4.jpg', 'Ferrari memes_Image_2.jpg', 'DHL memes_Image_2.jpeg', 'UPS memes_Image_3.jpeg', 'UPS memes_Image_2.jpeg'], 'gucci': ['Gucci memes_Image_3.jpg', 'Gucci memes_Image_1.jpg', 'Gucci memes_Image_2 (copy 1).jpg', 'Gucci memes_Image_2.jpg'], 'pepsi': ['Pepsi memes_Image_2.jpeg', 'Pepsi memes_Image_3.jpg', 'Pepsi memes_Image_5.jpeg'], 'sony': ['Sony memes_Image_2.jpg', 'Sony memes_Image_3.jpeg', 'Sony memes_Image_2 (copy 1).jpg', 'Sony memes_Image_1.jpeg'], 'visa': ['Visa memes_Image_3.jpg', 'AmericanExpress memes_Image_2.jpg', 'Visa memes_Image_2.jpg', 'Visa memes_Image_5.jpg', 'Visa memes_Image_4.jpg'], 'salesforce': ['Salesforce memes_Image_1.jpeg', 'Salesforce memes_Image_3.jpeg', 'Salesforce memes_Image_5.jpeg', 'Salesforce memes_Image_2.jpeg', 'Salesforce memes_Image_4.jpg'], 'netflix': ['Netflix memes_Image_4.jpg', 'Cisco memes_d77c72918e7440d37fc3eabc307ad8ac3a8025383ea34cdb6e1df27c047628f6_1.jpg', 'Netflix memes_Image_5.jpg', 'Netflix memes_Image_1.jpg'], 'paypal': ['PayPal memes_Image_4.jpeg', 'PayPal memes_Image_2.jpg', 'PayPal memes_Image_3.jpeg', 'PayPal memes_Image_1.jpeg'], 'mastercard': ['Mastercard memes_Image_2.jpg',
    'Mastercard memes_Image_4.jpg', 'Mastercard memes_Image_1.jpeg', 'AmericanExpress memes_Image_2.jpg', 'Visa memes_Image_5.jpg'], 'adidas': ['adidas memes_Image_1.jpeg'], 'zara': ['Zara memes_Image_1.jpeg', 'Zara memes_Image_3.jpeg', 'Zara memes_Image_5.jpg'], 'axa': ['AXA memes_gettyimages-1226469012-612x612.jpg'], 'audi': ['Audi memes_Image_5.jpg', 'MercedesBenz memes_Image_1.jpg', 'Audi memes_Image_4.jpg'], 'airbnb': ['Airbnb memes_Image_3.png', 'Airbnb memes_Image_1.jpg', 'Airbnb memes_Image_4.jpeg', 'Airbnb memes_Image_2.jpg', 'Airbnb memes_P4BwmWt.jpg'], 'porsche': ['Porsche memes_Image_2.jpeg', 'Porsche memes_Image_1.jpg', 'Porsche memes_Image_4.jpg'], 'starbucks': ['Starbucks memes_Image_4.jpg', 'Starbucks memes_Image_2.png', 'Starbucks memes_4e59a310d37f67cbb82bcd067ee4e0c7.jpg', 'Starbucks memes_Image_3.png'], 'ge': ['GE memes_images (2).jpeg'], 'volkswagen': ['Volkswagen memes_Image_3.jpg', 'Volkswagen memes_Image_1.jpeg'], 'ford': ['Ford memes_Image_5.jpeg'], 'nescafé': [], 'siemens': ['Siemens memes_Image_5.jpg', 'Siemens memes_siemens-heh-heh-heh-v0-yf2fbabx814b1.webp', 'Siemens memes_Image_4.jpg', 'Siemens memes_A6Db4nQCQAEw2Xg.jpg'], 'goldman sachs': ['GoldmanSachs memes_Image_5.jpg', 'GoldmanSachs memes_Image_2.jpg', 'GoldmanSachs memes_Image_3.jpg', 'GoldmanSachs memes_Image_4.jpg'], 'pampers': ['Pampers memes_Image_3.jpg', 'Pampers memes_Image_5 (copy 1).jpg', 'Pampers memes_Image_5.jpg'], 'h&m': ['HM memes_tandem-x-visuals-FZOOxR2auVI-unsplash-1313x900.webp', 'HM memes_Image_4.jpeg'], 'l’oréal paris': ['LOralParis memes_Image_3.jpg', 'LOralParis memes_Image_1.jpg', 'LOralParis memes_Image_2.jpeg'], 'citi': ['Cisco memes_d77c72918e7440d37fc3eabc307ad8ac3a8025383ea34cdb6e1df27c047628f6_1.jpg', 'Citi Bank memes_Image_4.jpg', 'MorganStanley memes_Image_3.jpg', 'Citi Bank memes_citi.jpg', 'Citi Bank memes_Image_2.jpg', 'Citi Bank memes_Image_5.jpg'], 'lego': ['LEGO memes_Image_5.jpg', 'LEGO memes_images.jpeg', 'LEGO memes_Image_4.jpeg'],
    'red bull': ['JackDaniels memes_Image_2.jpg', 'RedBull memes_Image_5.jpg', 'RedBull memes_Image_2.jpg', 'RedBull memes_Image_4.jpeg', 'RedBull memes_Image_3.jpg'], 'budweiser': ['Budweiser memes_Image_3.jpg', 'Budweiser memes_aoKL853_460s.jpg', 'Budweiser memes_Image_2.jpg', 'Budweiser memes_Image_4.jpg', 'Budweiser memes_Image_1.jpg'], 'ebay': ['eBay memes_Image_4.jpg', 'eBay memes_Image_2.jpg'], 'nissan': ['Nissan memes_Image_2.jpeg', 'Nissan memes_Image_4.jpg', 'Nissan memes_Image_3.jpeg', 'Nissan memes_Image_3.jpg'], 'hp': ['HewlettPackardEnterprise memes_k8m3z.jpg', 'HP memes_38ef2b77327f4020f04dc14dc41f1538.jpg', 'HP memes_190d521ac62eaf40a74e50ab39635231.jpg'], 'hsbc': ['HSBC memes_Image_1.jpg', 'HSBC memes_Image_3.jpg', 'HSBC memes_images.jpeg'], 'morgan stanley': ['MorganStanley memes_4521ccf691de4bacef41ccbb143387b2.jpg', 'MorganStanley memes_Image_4.jpg', 'MorganStanley memes_images.jpeg', 'MorganStanley memes_Image_1.jpg'], 'nestle': ['Nestl memes_Image_1.jpeg', 'Nestl memes_Image_2.jpeg', 'Nestl memes_Image_3.jpeg', 'Nestl memes_Image_4.jpeg', 'Nestl memes_Image_5.jpg'], 'philips': ['Philips memes_Image_4.jpeg', 'Philips memes_images.jpeg', 'Philips memes_Philips-Innovationandyou-1002x1417.jpg', 'Philips memes_unnamed.jpg'], 'spotify': ['Spotify memes_spotify-now-listen-most-boring-pop-charts-yes-hey-xx-xx-tap-banner-now-last-ad-just-two-songs-ago.png', 'Spotify memes_Image_5.jpeg'], 'ferrari': ['Ferrari memes_Image_5.jpg', 'Ferrari memes_Image_1.jpg', 'Ferrari memes_Image_1.jpeg', 'Ferrari memes_Image_2.jpg', 'Ferrari memes_Image_3.png'], 'nintendo': ['Philips memes_Image_4.jpeg', 'Nintendo memes_Image_5.jpeg', 'Nintendo memes_Image_4.jpeg', 'Panasonic memes_Image_3.jpeg'], 'gillette': ['Gillette memes_Image_1.jpeg', 'Gillette memes_Image_2.jpg', 'Gillette memes_Image_4.jpeg', 'Gillette memes_Image_5.jpg'], 'colgate': ['Colgate memes_Image_4.jpeg', 'Colgate memes_Image_2.jpg', 'LOralParis memes_Image_2.jpeg'], 'cartier':
    ['Cartier memes_CaGpc-1VIAAOpWc.png', 'Cartier memes_Image_3.jpg', 'Cartier memes_Image_1.jpeg', 'Cartier memes_65c080cea63c4.jpeg'], '3m': ['3M memes_Image_2.jpg', '3M memes_Image_5.jpg', '3M memes_Image_3.jpg', '3M memes_khkn86l1vgs41.webp'], 'dior': ['Dior memes_Image_3.jpg', 'Dior memes_Image_5.jpg', 'Dior memes_Image_2.jpg', 'Dior memes_Image_4.jpeg', 'Dior memes_Image_4.jpg'], 'santander': ['Santander memes_Image_2.jpg', 'Santander memes_Image_3.jpg', 'Santander memes_Image_4.jpg', 'Santander memes_meme1.webp'], 'danone': ['Danone memes_Image_2.jpg', 'Danone memes_Image_3.jpg', 'Danone memes_5lmocm.jpg', 'Danone memes_5b1d46878de83.jpeg', 'Danone memes_page_1.webp'], "kellogg's": ['Kelloggs memes_Image_3.jpeg', 'Kelloggs memes_Image_3.jpg'], 'linkedin': ['LinkedIn memes_Image_2.jpg'], 'corona': ['Corona memes_om9ysstu92k41.jpg'], 'fedex': ['FedEx memes_Image_1.jpg', 'FedEx memes_Image_5.jpg', 'FedEx memes_Image_2.jpeg', 'FedEx memes_Image_3.jpg'], 'caterpillar': ['Caterpillar memes_image-1.webp'], 'dhl': ['DHL memes_Image_5.jpg', 'DHL memes_Image_4.png', 'DHL memes_Image_1.jpg', 'DHL memes_Image_2.jpeg'], "jack daniel's": ['JackDaniels memes_Image_4.jpg'], 'prada': ['Prada memes_highxtar-prada-ss21-campaign-4.jpg', 'Prada memes_highxtar-prada-ss21-campaign-1.jpg', 'Prada memes_DIETPRADA3.webp', 'Prada memes_highxtar-prada-ss21-campaign-2.jpg'], 'xiaomi': ['Xiaomi memes_9daf00f7-8607-4ff6-b8e1-1dd5c755f43a.jpeg', 'Xiaomi memes_Image_4.jpeg', 'Xiaomi memes_Image_1.jpeg'], 'kia': [], 'tiffany & co.': [], 'panasonic': ['Panasonic memes_Image_5.jpg', 'Panasonic memes_Image_4.jpeg', 'Panasonic memes_Image_1.jpeg'], 'hewlett packard enterprise': ['hp enterprise', 'hewlett packard', 'HewlettPackardEnterprise memes_maxresdefault.jpg', 'HewlettPackardEnterprise memes_EPo6d7mUYAAHsU0.jpg'], 'huawei': ['Huawei memes_Image_5.jpg', 'Huawei memes_u8978vigrj031.webp', 'Huawei memes_Image_3.jpg'], 'hennessy': ['Hennessy memes_Image_4.jpg', 'Hennessy memes_Image_3.jpg',
    'Hennessy memes_Image_2.jpg', 'Hennessy memes_Image_1.jpg', 'Hennessy memes_Image_5.jpg'], 'burberry': ['Burberry memes_images.jpeg', 'Burberry memes_Image_3.jpg', 'Burberry memes_4febf2b9add2cf73f8e3bcca6f67c904.jpg', 'Burberry memes_Image_5.jpg'], 'kfc': ['KFC memes_Image_4.jpg', 'KFC memes_Image_1.jpeg', 'KFC memes_Image_5.jpeg', 'KFC memes_Image_2 (copy 1).jpg'], 'johnson & johnson': ['JohnsonJohnson memes_jj-vaccine-covid-19-vaccine-memes-covid-19-memes-funny-memes-memes-twitter-memes-funny-tweets.jpeg', 'JohnsonJohnson memes_memes-for-use-to-spam-j-j-v0-2kmhfpfrrmbb1.webp', 'JohnsonJohnson memes_Image_1.jpg'], 'sephora': ['Sephora memes_Image_5.png', 'Sephora memes_Image_2.jpg', 'Sephora memes_Image_1.jpg', 'Sephora memes_Image_4.jpeg', 'Sephora memes_Image_3.jpg'], 'nespresso': ['Nespresso memes_Image_3.jpg', 'Nespresso memes_Image_1.jpg', 'Nespresso memes_Image_2.jpg', 'Nespresso memes_Image_4.jpg', 'Nespresso memes_Image_5.jpg'], 'heineken': ['Corona memes_images.jpeg', 'Heineken memes_Image_1.jpg', 'Heineken memes_Image_5.jpg'], 'canon': ['Canon memes_55bc7be8031ed.jpeg', 'Canon memes_7ugwbf.jpg', 'Canon memes_Image_1.jpg', 'Canon memes_xzgpcaalz69a1.jpg'],
}
# Merge the model predictions (dict1) with the hard-coded mapping (dict2).
# Every brand key found in either dict maps to dict1's file list followed by
# dict2's. Keys keep dict1's order, then come the keys that only dict2 has.
# Each value is a fresh list, so later edits to `result` cannot leak back into
# dict1/dict2.
# NOTE(review): dict1 keys are lower-cased brand_names entries (e.g.
# "mercedesbenz", "hermès"), while dict2 uses spellings such as
# "mercedes-benz" / "hermes". Those brands therefore end up under two separate
# keys. File names present in both sources are kept twice.
result = {
    key: [*dict1.get(key, []), *dict2.get(key, [])]
    for key in {**dict1, **dict2}
}
print("finished")

In [None]:
# Print each predicted brand with its matched file names, numbered in the
# order the brands were first encountered.
for i, key in enumerate(dict1):
    value_list = dict1[key]
    print(f"{i} \t {key}: {value_list}\n")

In [None]:
# Display the merged brand -> file-name mapping as this cell's output.
result

DagsHub connection and authentication

In [None]:
# Shell command (IPython `!`): delete this specific checkpoint from
# /kaggle/working/DIRPATH; -f suppresses the error if it is already gone.
!rm -fr /kaggle/working/DIRPATH/model-epoch=14-val_loss=1.60.ckpt