## Install Vision Transformer (ViT)

In [None]:
! pip -q install vit_pytorch linformer

In [None]:
# linformer is already installed by the vit_pytorch/linformer pip cell above; no separate install is needed.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, f1_score
import seaborn as sns

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms, utils
import torch.optim as optim
from torch.optim import lr_scheduler

import time
import os
import zipfile
from copy import deepcopy

%matplotlib inline

# Prefer the second GPU ("cuda:1") as before, but fall back to "cuda:0" on single-GPU
# machines (where "cuda:1" does not exist and .to(device) would fail), and to the CPU
# when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print(device)

# vision transformer
from linformer import Linformer
from vit_pytorch.efficient import ViT

In [None]:
## Custom Galaxy Zoo 2 Dataset
class GalaxyZooDataset(Dataset):
    """Galaxy Zoo 2 image dataset backed by a label CSV and a flat image directory.

    Each CSV row must provide a ``galaxyID`` column and an integer ``label1`` column;
    the image for a row is read from ``<images_dir>/<galaxyID>.jpg``. Any other CSV
    columns are dropped.
    """

    def __init__(self, csv_file, images_dir, transform=None):
        """
        Args:
            csv_file (string): path to the label csv (needs 'galaxyID' and 'label1' columns)
            images_dir (string): path to the dir containing all images
            transform (callable, optional): transform applied to the PIL image
        """
        self.labels_df = pd.read_csv(csv_file)
        # keep only the two columns __getitem__ reads, at positions 0 and 1
        self.labels_df = self.labels_df[['galaxyID', 'label1']].copy()

        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows of the label CSV)."""
        return len(self.labels_df)

    def __getitem__(self, idx):
        """
        Get the idx-th sample.

        Returns:
            (image, label, galaxy_id): the RGB image (transformed if a transform was given),
            the integer class label, and the galaxy ID as an int.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # galaxy ID; str() works whether pandas parsed the column as integers or strings
        galaxyid = str(self.labels_df.iloc[idx, 0])
        # path of the image
        image_path = os.path.join(self.images_dir, galaxyid + '.jpg')
        # Read the image inside a context manager so the file handle is closed promptly,
        # and force 3-channel RGB so grayscale/CMYK JPEGs still match the 3-channel Normalize.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # apply transform (optional)
        if self.transform is not None:
            image = self.transform(image)
        # read the true label
        label = int(self.labels_df.iloc[idx, 1])

        return image, label, int(galaxyid)

## Custom Transforms

In [None]:
def create_data_transforms(is_for_inception=False, input_size=None):
    """
    Create Pytorch data transforms for the GalaxyZoo datasets.

    Args:
        is_for_inception (bool): True for Inception networks (299x299 crops),
            False otherwise (224x224 crops).
        input_size (int, optional): explicit crop size; when given it overrides the
            size implied by ``is_for_inception``.

    Returns:
        train_transform: augmenting transform for the training data
        valid_transform: deterministic transform for the validation data
        test_transform: deterministic transform for the test data (same pipeline as
            valid_transform, built as a separate object)
    """
    if input_size is None:
        input_size = 299 if is_for_inception else 224

    # Per-channel (R, G, B) normalization statistics shared by all three pipelines.
    norm_mean = (0.094, 0.0815, 0.063)
    norm_std = (0.1303, 0.11, 0.0913)

    def _eval_transform():
        # deterministic pipeline: center crop -> tensor -> normalize
        return transforms.Compose([transforms.CenterCrop(input_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize(norm_mean, norm_std)])

    # transforms for training data
    # NOTE(review): RandomRotation runs after CenterCrop, so rotated training images get
    # zero-filled corners that eval images never have; rotating before the crop would
    # avoid this. Confirm the current order is intended before changing it.
    train_transform = transforms.Compose([transforms.CenterCrop(input_size),
                                          transforms.RandomRotation(90),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomVerticalFlip(),
                                          transforms.RandomResizedCrop(input_size, scale=(0.8, 1.0), ratio=(0.99, 1.01)),
                                          transforms.ToTensor(),
                                          transforms.Normalize(norm_mean, norm_std)])

    # transforms for validation and test data (identical pipelines, separate objects)
    valid_transform = _eval_transform()
    test_transform = _eval_transform()

    return train_transform, valid_transform, test_transform

## Training Function

In [None]:
# Default directory where train_model writes the best-so-far checkpoint "<model_name>_cache.pth".
_DEFAULT_CHECKPOINT_DIR = '/remote-home/cs_acmis_hby/Galaxy-Zoo-Classification/Contrast_experiment/Galaxy-Zoo-Classification/dataset_07'


def _model_device(model):
    """Return the device of the model's first parameter (CPU for a parameter-less model)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over `loader`: train (backprop + step) when `optimizer` is given, else evaluate.

    Args:
        model: Pytorch classifier returning per-class logits
        loader: iterable of (images, labels, ids) batches
        criterion: loss function object
        optimizer: optimizer to step; None means evaluation only (no gradients)
    Returns:
        (mean_loss, accuracy): batch losses weighted by batch size and averaged over all
        samples, and the fraction of correctly classified samples; (0.0, 0.0) if empty.
    """
    training = optimizer is not None
    model.train(training)
    dev = _model_device(model)

    cum_loss = 0.0
    cum_corrects = 0
    n_samples = 0
    # gradients are only tracked while training
    with torch.set_grad_enabled(training):
        for images, labels, _ in loader:
            images = images.to(dev)
            labels = labels.long().to(dev)

            if training:
                optimizer.zero_grad()
            pred_logits = model(images)
            loss = criterion(pred_logits, labels)
            if training:
                loss.backward()
                optimizer.step()

            pred_classes = torch.argmax(pred_logits.detach(), dim=1)
            batch_size = images.size(0)
            cum_loss += loss.item() * batch_size
            cum_corrects += (pred_classes == labels).sum().item()
            n_samples += batch_size

    if n_samples == 0:
        return 0.0, 0.0
    return cum_loss / n_samples, cum_corrects / n_samples


def train_model(model, num_epochs, criterion, optimizer, scheduler, print_every=1, early_stop_epochs=10,
                train_loader=None, valid_loader=None, checkpoint_dir=None, model_name=None):
    """
    Train the model, keeping the weights with the best validation accuracy.

    Args:
        model: Pytorch neural model (already moved to its target device)
        num_epochs: maximum number of epochs to train
        criterion: the loss function object
        optimizer: the optimizer
        scheduler: the learning rate decay scheduler (stepped once per epoch)
        print_every: print the information every X epochs
        early_stop_epochs: stop once the model hasn't improved for X epochs
        train_loader, valid_loader: DataLoaders yielding (images, labels, ids); default to
            the notebook-level variables of the same name
        checkpoint_dir: directory for "<model_name>_cache.pth" (created if missing);
            defaults to _DEFAULT_CHECKPOINT_DIR
        model_name: checkpoint file prefix; defaults to the notebook-level `model_name`
    Returns:
        (model, history_dic): the model loaded with its best weights, and a dict of
        per-epoch lists 'train_loss', 'train_acc', 'valid_loss', 'valid_acc', 'lr'.
    """
    # fall back to notebook-level objects so existing calls keep working unchanged
    if train_loader is None:
        train_loader = globals()['train_loader']
    if valid_loader is None:
        valid_loader = globals()['valid_loader']
    if model_name is None:
        model_name = globals()['model_name']
    if checkpoint_dir is None:
        checkpoint_dir = _DEFAULT_CHECKPOINT_DIR
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, model_name + "_cache.pth")

    # cache of the best model so far
    best_model_weights = deepcopy(model.state_dict())
    best_train_acc = 0.0
    best_valid_acc = 0.0
    # 1-based epoch of the best model; -1 until validation accuracy first improves
    best_epoch = -1

    history_dic = {key: [] for key in ('train_loss', 'train_acc', 'valid_loss', 'valid_acc', 'lr')}

    for epoch in range(num_epochs):
        epoch_start_time = time.time()

        train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        valid_loss, valid_acc = _run_epoch(model, valid_loader, criterion)

        history_dic['train_loss'].append(train_loss)
        history_dic['train_acc'].append(train_acc)
        history_dic['valid_loss'].append(valid_loss)
        history_dic['valid_acc'].append(valid_acc)
        history_dic['lr'].append(scheduler.get_last_lr()[0])

        # check if this is the best validation accuracy so far
        is_best = valid_acc > best_valid_acc
        if is_best:
            best_train_acc = train_acc
            best_valid_acc = valid_acc
            best_epoch = epoch + 1
            best_model_weights = deepcopy(model.state_dict())
            # persist the best-so-far weights so they survive an interrupted run
            torch.save(best_model_weights, checkpoint_path)

        mm, ss = divmod(time.time() - epoch_start_time, 60)

        if (epoch + 1) % print_every == 0:
            print("Epoch {}/{}\tTrain loss: {:.4f}\tTrain acc: {:.4f}\tValid loss: {:.4f}\tValid acc: {:.4f}\tTime: {:.0f}m {:.0f}s{}".format(
                epoch + 1, num_epochs, train_loss, train_acc, valid_loss, valid_acc, mm, ss, "\t<--" if is_best else ""))

        # Early stopping; max(best_epoch, 0) gives a never-improving model the full
        # early_stop_epochs patience window instead of one epoch less.
        if (epoch + 1) - max(best_epoch, 0) >= early_stop_epochs:
            print("Early stopping... (Model did not improve after {} epochs)".format(early_stop_epochs))
            break

        scheduler.step()

    # load the best weights into the model
    model.load_state_dict(best_model_weights)
    print("Best epoch = {}, with training accuracy = {:.4f} and validation accuracy = {:.4f}".format(best_epoch, best_train_acc, best_valid_acc))

    return model, history_dic

In [None]:
import os
import pandas as pd

# Root directory of the raw images: one sub-directory per class, each holding .jpg files.
root_dir = '/remote-home/cs_acmis_hby/Galaxy-Zoo-Classification/Contrast_experiment/gz2_images_dataset07'

# Class sub-directories, sorted so the class -> index mapping is deterministic.
class_names = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
class_to_idx = {name: idx for idx, name in enumerate(class_names)}

# Record every image's path and label. File names are sorted because os.listdir order is
# arbitrary; without sorting, the row order of gz2_all.csv (and therefore the seeded
# train/valid/test split built from it) is not reproducible across machines.
records = []
for class_name in class_names:
    class_dir = os.path.join(root_dir, class_name)
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith('.jpg'):
            galaxy_id = os.path.splitext(fname)[0]
            records.append({'galaxyID': galaxy_id, 'label1': class_to_idx[class_name], 'src_path': os.path.join(class_dir, fname)})

df_all = pd.DataFrame(records)
# Images are later copied to "<galaxyID>.jpg" and split rows are matched by galaxyID,
# so a duplicated ID would silently overwrite images and leak rows across splits.
assert df_all.empty or df_all['galaxyID'].is_unique, "Duplicate galaxyIDs found across class folders"
df_all.to_csv('gz2_all.csv', index=False)


In [None]:
from sklearn.model_selection import train_test_split

# Target proportions of gz2_all.csv: ~79% train / ~1% valid / ~20% test, stratified by label1.
SPLIT_SEED = 42
# share of the full set held out for valid + test
HELDOUT_FRACTION = 0.21
# share of the held-out part that goes to test: 1 - 0.01 / (0.01 + 0.20) ≈ 0.9524
TEST_SHARE_OF_HELDOUT = 0.9524

df_all = pd.read_csv('gz2_all.csv')

# Step 1: train vs. held-out part
df_train, df_temp = train_test_split(
    df_all,
    test_size=HELDOUT_FRACTION,
    stratify=df_all['label1'],
    random_state=SPLIT_SEED,
)

# Step 2: held-out part -> validation + test
df_valid, df_test = train_test_split(
    df_temp,
    test_size=TEST_SHARE_OF_HELDOUT,
    stratify=df_temp['label1'],
    random_state=SPLIT_SEED,
)

# Save only the columns GalaxyZooDataset reads
df_train[['galaxyID', 'label1']].to_csv('gz2_train.csv', index=False)
df_valid[['galaxyID', 'label1']].to_csv('gz2_valid.csv', index=False)
df_test[['galaxyID', 'label1']].to_csv('gz2_test.csv', index=False)


In [None]:
import shutil

def copy_images(df, split_name):
    """Copy each image listed in `df` to ./images_<split_name>/<galaxyID>.jpg.

    Args:
        df: DataFrame with 'galaxyID' and 'src_path' columns
        split_name: suffix of the destination directory (created if missing)
    """
    target_dir = f'images_{split_name}'
    os.makedirs(target_dir, exist_ok=True)

    # existing files with the same name are overwritten by shutil.copy
    for galaxy_id, source_path in zip(df['galaxyID'], df['src_path']):
        shutil.copy(source_path, os.path.join(target_dir, f"{galaxy_id}.jpg"))

# Read the full index (it still carries src_path) and select each split's rows by galaxyID.
df_all_full = pd.read_csv('gz2_all.csv')


def _rows_in_split(split_csv):
    """Return the rows of df_all_full whose galaxyID appears in `split_csv`."""
    split_ids = pd.read_csv(split_csv)['galaxyID']
    return df_all_full[df_all_full['galaxyID'].isin(split_ids)]


df_train_full = _rows_in_split('gz2_train.csv')
df_valid_full = _rows_in_split('gz2_valid.csv')
df_test_full = _rows_in_split('gz2_test.csv')

# Copy the images into one directory per split
for split_frame, split_name in ((df_train_full, 'train'), (df_valid_full, 'valid'), (df_test_full, 'test')):
    copy_images(split_frame, split_name)


In [None]:
"""
Data Loader
"""
# the batch size
BATCH_SIZE = 64

# create transforms
train_transform, valid_transform, test_transform = create_data_transforms(is_for_inception=False)

# create datasets
data_train = GalaxyZooDataset('gz2_train.csv', 'images_train', train_transform)
data_valid = GalaxyZooDataset('gz2_valid.csv', 'images_valid', valid_transform)
data_test = GalaxyZooDataset('gz2_test.csv', 'images_test', test_transform)

# Dataloaders. Only the training data is shuffled: the order of the validation and test
# data does not affect the metrics, and a fixed order makes the saved predictions
# file deterministic.
train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(data_valid, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=False)

# check the sizes
print("**Dataloaders**")
print("Number of training data: {} ({} batches)".format(len(data_train), len(train_loader)))
print("Number of validation data: {} ({} batches)".format(len(data_valid), len(valid_loader)))
print("Number of test data: {} ({} batches)".format(len(data_test), len(test_loader)))
print("===============================")
print("===============================")

## Hyperparameters

In [None]:
"""
Parameters
"""
# ViT / Linformer architecture
PATCH_SIZE = 28
DEPTH = 12
HIDDEN_DIM = 128
K_DIM = 64
NUM_HEADS = 8
LIN_DROPOUT = 0.1

# Optimisation schedule
LR = 3e-4
STEP_SIZE = 5
GAMMA = 0.9
MAX_EPOCH = 200

# Per-class weights for the cross-entropy loss (8 classes); all ones = unweighted loss.
# Alternatives tried earlier:
#   [1., 1., 2., 1., 1., 1., 4., 3.]
#   [4.6, 4.0, 20.9, 9.6, 7.7, 5.1, 26.5, 73.3]
class_weights = torch.ones(8, dtype=torch.float32).to(device)

## Create ViT model

In [None]:
## file name
model_name = "gz2_vit_0511_1849"

# input resolution; must match the crop size produced by create_data_transforms
IMAGE_SIZE = 224
NUM_CLASSES = 8

# Linformer sequence length: one token per image patch plus one for the 'cls' pooling token.
# Integer arithmetic plus an explicit check, because a non-dividing patch size would
# otherwise produce a silently wrong length.
assert IMAGE_SIZE % PATCH_SIZE == 0, "PATCH_SIZE ({}) must divide IMAGE_SIZE ({})".format(PATCH_SIZE, IMAGE_SIZE)
seq_len = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1

## Linformer
lin = Linformer(dim=HIDDEN_DIM, seq_len=seq_len, depth=DEPTH, k=K_DIM, heads=NUM_HEADS,
                dim_head=None, one_kv_head=False, share_kv=False, reversible=False, dropout=LIN_DROPOUT)

## Vision Transformer
model = ViT(image_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_classes=NUM_CLASSES, dim=HIDDEN_DIM, transformer=lin, pool='cls', channels=3)

# print out model details
run_settings = [("patch_size", PATCH_SIZE), ("depth", DEPTH), ("hidden_dim", HIDDEN_DIM),
                ("k_dim", K_DIM), ("num_heads", NUM_HEADS), ("dropout", LIN_DROPOUT),
                ("batch_size", BATCH_SIZE), ("lr", LR), ("step_size", STEP_SIZE),
                ("gamma", GAMMA), ("max_epoch", MAX_EPOCH), ("class weights", class_weights)]
print("*******[ " + model_name + " ]*******")
print("===============================")
for setting_name, setting_value in run_settings:
    print("{} = {}".format(setting_name, setting_value))
print("===============================")
print("Number of trainable parameters: {}".format(sum(param.numel() for param in model.parameters() if param.requires_grad)))
print("===============================")
print("===============================")

In [None]:
# Directory that receives the training checkpoints, the final weights and the history.
# Created up front because torch.save / to_csv fail if it does not exist.
results_dir = '/remote-home/cs_acmis_hby/Galaxy-Zoo-Classification/Contrast_experiment/Galaxy-Zoo-Classification/dataset_07'
os.makedirs(results_dir, exist_ok=True)

# move to gpu
model = model.to(device)

# loss function
criterion = nn.CrossEntropyLoss(weight=class_weights)

# optimizer
optimizer = optim.Adam(model.parameters(), lr=LR)

# scheduler
scheduler = lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

## train and return the best model
model, history_dic = train_model(model, MAX_EPOCH, criterion, optimizer, scheduler, print_every=1, early_stop_epochs=10)

## Save the best weights (local copy + results directory)
torch.save(model.state_dict(), model_name + '.pth')
torch.save(model.state_dict(), os.path.join(results_dir, model_name + '.pth'))

## Convert history to dataframe and save it (local copy + results directory)
history_df = pd.DataFrame(history_dic)
history_df.to_csv(model_name + '_history.csv', index=False)
history_df.to_csv(os.path.join(results_dir, model_name + '_history.csv'), index=False)

In [None]:
# Reload the best-validation checkpoint written during training. map_location keeps the
# load working when the GPU the checkpoint was saved from is not the current `device`.
model.load_state_dict(torch.load(os.path.join('/remote-home/cs_acmis_hby/Galaxy-Zoo-Classification/Contrast_experiment/Galaxy-Zoo-Classification/dataset_07', model_name + '_cache.pth'), map_location=device))

## Predict test data

In [None]:
def predict_model(model, loader=None):
    """
    Predict the test data.

    Args:
        model: trained Pytorch classifier returning per-class logits
        loader: DataLoader yielding (images, labels, galaxy_ids); defaults to the
            notebook-level `test_loader`
    Returns:
        y_true (list of int), y_pred (list of int), and predict_df: a DataFrame with
        columns 'GalaxyID', 'class', 'pred', all in loader order.
    """
    if loader is None:
        loader = globals()['test_loader']

    # run on whichever device the model's parameters live on
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')

    # evaluation
    model.eval()

    y_true = []
    y_pred = []
    y_label = []

    with torch.no_grad():
        for images, labels, galaxy_id in loader:
            pred_logits = model(images.to(dev))
            pred_classes = torch.argmax(pred_logits, dim=1)

            # reshape(-1) instead of squeeze(): squeezing a size-1 final batch yields a
            # 0-d tensor whose .tolist() is a bare int, which cannot extend a list
            y_true.extend(labels.long().reshape(-1).tolist())
            y_pred.extend(pred_classes.cpu().reshape(-1).tolist())
            y_label.extend(torch.as_tensor(galaxy_id).reshape(-1).tolist())

    # create a DataFrame with columns 'GalaxyID', 'class', 'pred'
    predict_df = pd.DataFrame(data={'GalaxyID': y_label, 'class': y_true, 'pred': y_pred})

    return y_true, y_pred, predict_df

In [None]:
# Evaluate on the test set (the model must be on the same device as in training).
model = model.to(device)
y_true, y_pred, predict_df = predict_model(model)

# Save the predictions twice: next to the notebook and in the results directory.
predictions_file = model_name + '_predictions.csv'
for output_path in (predictions_file,
                    os.path.join('/remote-home/cs_acmis_hby/Galaxy-Zoo-Classification/Contrast_experiment/Galaxy-Zoo-Classification/dataset_07', predictions_file)):
    predict_df.to_csv(output_path, index=False)

## Evaluation metrics

In [None]:
# Galaxy classes: display name for each integer label (index i <-> label i).
# NOTE(review): this order presumably matches class_to_idx, i.e. the sorted class-folder
# names under root_dir. Confirm against class_names before trusting the axis labels.
gxy_labels = [
    'Barred Spiral',            # label0
    'Cigar-shaped Elliptical',  # label1
    'Edge-on',                  # label2
    'In-between Elliptical',    # label3
    'Irregular',                # label4
    'Merger',                   # label5
    'Round Elliptical',         # label6
    'Unbarred Spiral'           # label7
]
class_ids = list(range(len(gxy_labels)))

# Row-normalised confusion matrix (one row per true class). Passing labels= keeps the
# matrix len(gxy_labels) x len(gxy_labels) even when a class never appears in y_true or
# y_pred, so it always lines up with the label names in cm_df.
cm = confusion_matrix(y_true, y_pred, labels=class_ids, normalize='true')
cm_df = pd.DataFrame(cm, index=gxy_labels, columns=gxy_labels)

# Per-class accuracy (i.e. recall). With normalize='true' each row already sums to 1, so
# this is the diagonal entry (0 rather than NaN for a class absent from y_true).
for c in class_ids:
    print("Class {}: accuracy = {:.4f} ({})".format(c, cm[c, c], gxy_labels[c]))
print("================")

# accuracy
acc = accuracy_score(y_true, y_pred)
print("Total Accuracy = {:.4f}\n".format(acc))

# recall, macro-averaged over the labels present in y_true / y_pred
recall = recall_score(y_true, y_pred, average='macro')
print("Recall = {:.4f}\n".format(recall))

# f1 score, macro-averaged over the labels present in y_true / y_pred
F1 = f1_score(y_true, y_pred, average='macro')
print("F1 score = {:.4f}\n".format(F1))

# plot confusion matrix
sns.set(font_scale=1.6)
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(cm_df, annot=True, fmt=".1%", cmap="YlGnBu", cbar=False, annot_kws={"size": 16}, ax=ax)
ax.set(xlabel="Predicted class", ylabel="True class", title="Row-normalised confusion matrix")
plt.show()