# Semi-supervised Food Image Classification

## Data Preparation and Initialization

### Import Packages

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import ConcatDataset, DataLoader, Subset, SequentialSampler, Dataset
from torchvision.datasets import DatasetFolder

from warmup_scheduler import GradualWarmupScheduler # https://github.com/ildoonet/pytorch-gradual-warmup-lr

from tqdm import tqdm

### Download Dataset

In [None]:
# Download the dataset archive by its file id with the `gdown` CLI and save it as food-11.zip.
!gdown '1vufDjKxj4IwRni11uxjM0CSim5WehscA' --output food-11.zip
# Extract the archive quietly (-q) into the current working directory.
# NOTE(review): later cells read from "food-11/...", so the archive presumably unpacks into that folder.
!unzip -q food-11.zip

### Dataset

In [None]:
class pseudo_dataset(Dataset):
    """In-memory dataset pairing already-transformed images with pseudo-labels.

    Args:
        data: Image data indexable along its first dimension. May be a
            (nested) list, a NumPy array or a tensor; it is stored as a
            float32 tensor.
        target: Sequence of labels, where ``target[i]`` belongs to ``data[i]``.

    The number of stored samples is printed on construction.
    """

    def __init__(self, data, target):
        # torch.as_tensor accepts lists, ndarrays and tensors of any numeric
        # dtype, unlike the legacy torch.FloatTensor constructor, and avoids
        # a copy when the input is already a float32 tensor.
        self.data = torch.as_tensor(data, dtype=torch.float32)
        self.target = target
        print(len(self.data))

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.data[index], self.target[index]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

### Data Augmentation

In [None]:
inputs_size = 200

# Shared building blocks for the three pipelines below.
_target_shape = (inputs_size, inputs_size)
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]


def _deterministic_pipeline():
    """Build the augmentation-free pipeline: resize, to-tensor, normalize."""
    return transforms.Compose([
        transforms.Resize(_target_shape),
        transforms.ToTensor(),
        transforms.Normalize(_channel_mean, _channel_std),
    ])


# Un-augmented view of the training images.
train_ori = _deterministic_pipeline()

# Augmented view of the training images. Geometric and colour augmentations
# run on the PIL image; RandomErasing runs last because it needs a tensor.
# NOTE(review): per the torchvision docs a two-element `shear` is a
# (min, max) range, so [25, 25] appears to apply a fixed 25-degree shear
# to every image rather than a random one -- confirm this is intended.
train_tfm = transforms.Compose([
    transforms.Resize(_target_shape),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(45, translate=(0.1, 0.1), shear=[25, 25]),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5, fill=0),
    transforms.ColorJitter(brightness=[0.9, 1.3], contrast=0.3, saturation=0.1, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.2)),
])

# Validation / test images are never augmented.
test_tfm = _deterministic_pipeline()

### Dataset Construction

In [None]:
batch_size = 32


def _load_rgb_image(path):
    """Load the image at ``path`` as a fully decoded RGB PIL image.

    ``Image.open`` is lazy and keeps the file handle open until the pixel
    data is read; the context manager closes it deterministically.
    ``convert("RGB")`` forces decoding and guarantees three channels, which
    the 3-channel ``Normalize`` transform requires. A named module-level
    function (rather than a lambda) also stays picklable for DataLoader
    worker processes on spawn-based platforms.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


# The labeled images appear twice in the training set: once through the
# augmenting pipeline and once through the plain one.
train_set_tra = DatasetFolder("food-11/training/labeled", loader=_load_rgb_image, extensions="jpg", transform=train_tfm)
train_set_ori = DatasetFolder("food-11/training/labeled", loader=_load_rgb_image, extensions="jpg", transform=train_ori)
train_set = ConcatDataset([train_set_tra, train_set_ori])
valid_set = DatasetFolder("food-11/validation", loader=_load_rgb_image, extensions="jpg", transform=test_tfm)
# NOTE(review): the unlabeled images go through the augmenting pipeline, so
# any pseudo-labels are predicted on augmented images -- confirm intended.
unlabeled_set = DatasetFolder("food-11/training/unlabeled", loader=_load_rgb_image, extensions="jpg", transform=train_tfm)
test_set = DatasetFolder("food-11/testing", loader=_load_rgb_image, extensions="jpg", transform=test_tfm)

### Dataloader Construction

In [None]:
# Settings shared by the training and validation loaders.
_loader_common = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

# Training batches are shuffled and the trailing incomplete batch is dropped.
train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **_loader_common)
# Validation batches are shuffled as well but keep the last partial batch.
valid_loader = DataLoader(valid_set, shuffle=True, **_loader_common)
# Test order must stay fixed so that prediction i belongs to test image i.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

## Model Training

### Helper functions

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

### Pseudo Labeling

In [None]:
def get_pseudo_labels(dataset, model, threshold=0.7):
    """Generate pseudo-labels for ``dataset`` using ``model``.

    Every sample is forwarded through the model in eval mode. A sample is
    kept when its highest softmax probability is strictly greater than
    ``threshold``; its pseudo-label is the index of that probability.

    Args:
        dataset: Dataset yielding ``(image, label)`` pairs; the labels are
            ignored.
        model: Classifier returning one row of logits per image.
        threshold: Minimum (exclusive) confidence for a sample to be kept.

    Returns:
        A ``pseudo_dataset`` holding the kept (already transformed) images
        and their predicted class indices.

    Side effects:
        Uses the module-level ``batch_size`` and leaves the model in train
        mode on return.
    """
    # Run inference on the device the model actually lives on; fall back to
    # the CUDA-availability check only for parameter-less models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Make sure the model is in eval mode.
    model.eval()

    # Neither shuffling nor dropping the last batch is wanted here: sample
    # order is irrelevant and every sample should be considered.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        pin_memory=True)

    kept_images = []
    target = []
    # Iterate over the dataset by batches.
    for img, _ in tqdm(dataloader):
        # No gradients are needed for inference.
        with torch.no_grad():
            logits = model(img.to(device))
            probs = torch.softmax(logits, dim=-1)
            confidence, predicted = probs.max(dim=-1)

        # Select confident samples with a boolean mask instead of a
        # per-sample Python loop; images stay tensors (on the CPU).
        keep = (confidence > threshold).cpu()
        if keep.any():
            kept_images.append(img[keep].float())
            target.extend(predicted.cpu()[keep].tolist())

    # torch.cat rejects an empty list, so fall back to an empty dataset.
    data = torch.cat(kept_images) if kept_images else []
    pseudo_set = pseudo_dataset(data, target)

    print()

    # Switch back to train mode for the caller.
    model.train()
    return pseudo_set

### Hyperparameters

In [None]:
# "cuda" only when GPUs are available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize a model, and put it on the device specified.
# NOTE(review): resnet18 is built with its default output size, so the
# checkpoint below must have been saved from the same architecture.
model = torchvision.models.resnet18(pretrained=False).to(device)
model.device = device
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (and vice versa) instead of raising while deserializing the tensors.
model.load_state_dict(torch.load('model.ckpt', map_location=device))
# Where the best model of this (pseudo-label) training run is saved.
save_path = 'model_pseudo.ckpt'

n_epochs = 500
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0008, weight_decay=2e-5)

# Linear warm-up over the first 10% of the epochs, then cosine annealing.
# NOTE(review): T_max spans all n_epochs although the cosine phase only
# starts after the warm-up -- confirm the schedule is meant to end early.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=0.1*n_epochs, after_scheduler=scheduler)


### Semi-Supervised Learning Initialization

In [None]:
# Flag indicating that semi-supervised learning is intended.
do_semi = True

# NOTE(review): the ratios, `top_threshold`, `switch` and `check` are not
# referenced by the visible training loop; they are kept for compatibility.
ratio_top, ratio_bottom = 0.7, 0.2

# `bottom_threshold` is the confidence threshold used for pseudo-labeling.
top_threshold, bottom_threshold = 0.99, 0.995

# Start with an empty pseudo-labeled set.
tmp_pseudoset = Subset(unlabeled_set, [])

# Accuracy bookkeeping shared across epochs.
best_acc = valid_acc = train_acc = last_acc = 0

switch, check = False, True

### Model Training

In [None]:
epoch = 0

# The pseudo-label confidence threshold does not change between epochs.
threshold = bottom_threshold

while epoch < n_epochs:
    # Once the best validation accuracy exceeds 0.7 (and semi-supervised
    # learning is enabled), train on labeled + pseudo-labeled data.
    if do_semi and best_acc > 0.7:

        # Re-label the unlabeled data only while validation accuracy has not
        # dropped below the level at which labels were last generated;
        # otherwise keep using the previous pseudo-labels.
        if valid_acc >= last_acc:
            last_acc = valid_acc
            pseudo_set = get_pseudo_labels(unlabeled_set, model, threshold)
            tmp_pseudoset = pseudo_set

        # Construct a new dataset and a data loader for training.
        # This is used in semi-supervised learning only.
        concat_dataset = ConcatDataset([train_set, tmp_pseudoset])

        train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=8, pin_memory=True, drop_last=True)

    # ---------- Training ----------
    # Make sure the model is in train mode before training.
    model.train()
    scheduler_warmup.step(epoch)

    # These are used to record information in training.
    train_loss = []
    train_accs = []

    # Iterate the training set by batches.
    for imgs, labels in tqdm(train_loader):
        # Move the batch to the device once and reuse it below.
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Forward the data.
        logits = model(imgs)

        # Calculate the cross-entropy loss.
        loss = criterion(logits, labels)

        # Gradients stored in the parameters in the previous step should be cleared out first.
        optimizer.zero_grad()

        # Compute the gradients for parameters.
        loss.backward()

        # Clip the gradient norms for stable training.
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)

        # Update the parameters with computed gradients.
        optimizer.step()

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels).float().mean()

        # Record the loss and accuracy as plain Python floats so that no
        # device tensors are kept alive across the epoch.
        train_loss.append(loss.item())
        train_accs.append(acc.item())

    # The average loss and accuracy of the training set is the average of the recorded values.
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)

    # Print the information.
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")
    print(f'lr = {get_lr(optimizer):3.6f}, threshold: {threshold:.4f}')

    # ---------- Validation ----------
    # Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.
    model.eval()

    # These are used to record information in validation.
    valid_loss = []
    valid_accs = []

    # Iterate the validation set by batches.
    for imgs, labels in tqdm(valid_loader):
        imgs = imgs.to(device)
        labels = labels.to(device)

        # No gradients are needed for the forward pass or the loss.
        with torch.no_grad():
            logits = model(imgs)
            loss = criterion(logits, labels)

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels).float().mean()

        # Record the loss and accuracy.
        valid_loss.append(loss.item())
        valid_accs.append(acc.item())

    # The reported validation metrics are the unweighted mean over batches.
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)

    # Keep only the checkpoint with the best validation accuracy so far.
    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(model.state_dict(), save_path)
        print(f'saving model with acc {best_acc:.5f}')

    # Print the information.
    print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")
    print()
    epoch += 1

### Model Testing

In [None]:
# Restore the best checkpoint saved during training. map_location keeps the
# load working when the checkpoint was written on a different device type.
model.load_state_dict(torch.load(save_path, map_location=device))

# Make sure the model is in eval mode.
# Some modules like Dropout or BatchNorm behave differently in training mode.
model.eval()

# Initialize a list to store the predictions.
predictions = []

# Iterate the testing set by batches.
# DatasetFolder always yields (image, label) pairs; there is no ground truth
# for the test set, so the label half of each batch is ignored.
for imgs, _ in tqdm(test_loader):
    # We don't need gradients in testing.
    with torch.no_grad():
        logits = model(imgs.to(device))

    # Take the class with greatest logit as prediction and record it.
    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

# Save predictions into the file.
with open("predict.csv", "w") as f:
    # The first row must be "Id, Category"
    f.write("Id,Category\n")

    # For the rest of the rows, each image id corresponds to a predicted class.
    f.writelines(f"{i},{pred}\n" for i, pred in enumerate(predictions))