In [14]:
import math
import torch
from torch import nn
import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
import torch.optim as optim
import torch.utils.data as data
import pickle
import os
import numpy as np
from sklearn.model_selection import train_test_split
import json
from torch.nn import functional as F
import matplotlib.pyplot as plt

from PIL import Image

In [None]:
# Single configuration dictionary read by the model classes, the trainer and
# the experiment save/load helpers defined in the cells below.
hyper_par = dict(
    # Kernel size and stride of the patch-embedding convolution.
    patch_size=1,
    # Width of every token vector.
    embed_dim=128,
    # Number of transformer blocks stacked in the encoder.
    num_hidden_layers=1,
    # Number of attention heads per block.
    num_attn_heads=1,
    # Dropout probabilities (0.0 disables dropout).
    hidden_dropout_prob=0.0,
    attn_probs_dropout_prob=0.0,
    # Total number of pixels per sample (40 * 40); Patch_Embed divides this
    # value by patch_size ** 2 to obtain the number of patches.
    image_size=1600,
    # Size of the classifier output.
    num_classes=2,
    # Input channels of the patch-embedding convolution.
    num_channels=1,
    # Whether the fused query/key/value projection has a bias term.
    qkv_bias=True,
    # Optimisation settings.
    batch_size=32,
    lr=0.001,
    epochs=10,
)

BATCH_SIZE = hyper_par['batch_size']

# Side length of the square lattice; each sample is reshaped to
# (LATTICE_SIDE, LATTICE_SIDE).
LATTICE_SIDE = 40
# Fixed seed so the train/test partition is the same on every run. Without it
# the cells that reload a saved model would visualise a "test" set that may
# contain samples the model was trained on.
SPLIT_SEED = 42

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only point these paths at files you trust.
# The paths are built with os.path.join: the previous "..\<name>" literals
# contained the invalid escape sequence "\I" and only resolved on Windows.
with open(os.path.join("..", "Ising2DFM_reSample_L40_T=All.pkl"), 'rb') as f:
    data = pickle.load(f)

# NOTE: the name `data` shadows the `torch.utils.data as data` import alias;
# it is kept for compatibility, and the module is always referenced below by
# its full path.
# Unpack the bit-packed samples (one bit per lattice site) and shape them as
# (num_samples, 40, 40). The sample count is inferred instead of hard-coded.
data = np.unpackbits(data).astype(int).reshape(-1, LATTICE_SIDE, LATTICE_SIDE)
# The former per-sample transforms.ToTensor() loop was removed: ToTensor only
# rescales uint8 input and its result was written back into this int array,
# so it left the values unchanged while iterating over every sample.

# Add the channel axis expected by Conv2d: (N, 40, 40) -> (N, 1, 40, 40).
data = np.expand_dims(data, 1)

with open(os.path.join("..", "Ising2DFM_reSample_L40_T=All_labels.pkl"), 'rb') as f:
    labels = pickle.load(f)

labels = torch.from_numpy(labels)
data = torch.from_numpy(data).float()

dataset = torch.utils.data.TensorDataset(data, labels)

# sklearn indexes the dataset sample by sample, so both splits are plain
# lists of (image, label) tuples (default split: 75% train / 25% test).
train_dataset, test_dataset = train_test_split(dataset, random_state=SPLIT_SEED)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)



In [None]:
#The model

class Patch_Embed(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A strided convolution whose kernel size and stride both equal
    ``patch_size`` projects every non-overlapping patch to ``embed_dim``
    features.

    Reads ``image_size``, ``patch_size``, ``num_channels`` and ``embed_dim``
    from ``hyper_par``. ``num_patches`` is computed as
    ``image_size // patch_size ** 2``, i.e. ``image_size`` is treated as the
    total pixel count of a sample.
    """

    def __init__(self, hyper_par):
        super().__init__()
        # Copy the relevant configuration entries onto the module.
        for field in ("image_size", "patch_size", "num_channels", "embed_dim"):
            setattr(self, field, hyper_par[field])

        pixels_per_patch = self.patch_size ** 2
        self.num_patches = self.image_size // pixels_per_patch

        self.projection = nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x):
        """Map (batch, num_channels, H, W) to (batch, patches, embed_dim)."""
        feature_map = self.projection(x)
        # Merge the two spatial axes into one patch axis, then move the
        # feature axis last.
        tokens = torch.flatten(feature_map, start_dim=2)
        return tokens.transpose(1, 2)
class Pos_Embed(nn.Module):
    """Fixed sinusoidal positional encoding.

    Builds a ``(1, max_len, embed_dim)`` table whose even feature indices hold
    sines and odd indices hold cosines of ``position * div_term``. The table
    is registered as the buffer ``pe``, so it has no trainable parameters.

    Args:
        embed_dim: Feature width of the tokens. Odd values are supported.
        max_len: Longest sequence length the table can serve.
    """

    def __init__(self,embed_dim,max_len=2000):
        super().__init__()

        self.embed_dim = embed_dim
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(math.log(10000.0) / embed_dim))

        # (max_len, ceil(embed_dim / 2)) table of angles, shared by sin and cos.
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd embed_dim there is one fewer odd column than even columns;
        # trimming the angles avoids the shape mismatch that used to crash.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self,x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table "
                f"size {self.pe.size(1)}; increase max_len"
            )
        return x + self.pe[:, :seq_len]

class Embeddings(nn.Module):
    """Patch embeddings + learnable [CLS] token + positional encoding.

    The forward pass embeds the image into patch tokens, prepends one [CLS]
    token per sample, adds the sinusoidal positional encoding and applies
    dropout.
    """

    def __init__(self, hyper_par):
        super().__init__()
        self.hyper_par = hyper_par
        self.patch_embeddings = Patch_Embed(hyper_par)
        # Learnable [CLS] token, shared across the batch.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hyper_par["embed_dim"]))
        # Fixed sinusoidal encoding; its forward already returns `x + pe`.
        self.position_embeddings = Pos_Embed(hyper_par["embed_dim"])
        self.dropout = nn.Dropout(hyper_par["hidden_dropout_prob"])

    def forward(self, x):
        """Return the token sequence with [CLS] at position 0."""
        x = self.patch_embeddings(x)
        batch_size = x.size(0)
        # Expand the single [CLS] token to one copy per sample.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # Pos_Embed.forward performs the addition itself. The previous
        # `x = x + self.position_embeddings(x)` therefore produced
        # `2 * x + pe`, doubling the token embeddings.
        # NOTE: checkpoints trained with the doubled formulation will give
        # different outputs with this corrected forward pass; retrain them.
        x = self.position_embeddings(x)
        x = self.dropout(x)
        return x


class MultiHeadAttn(nn.Module):
    """Multi-head scaled dot-product self-attention.

    One fused linear layer produces queries, keys and values; the per-head
    outputs are concatenated and projected back to ``embed_dim``.

    Reads ``embed_dim``, ``num_attn_heads``, ``qkv_bias``,
    ``attn_probs_dropout_prob`` and ``hidden_dropout_prob`` from ``hyper_par``.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_attn_heads``.
            Previously this only surfaced as a tensor ``view`` error during
            the forward pass.
    """

    def __init__(self, hyper_par):
        super().__init__()
        self.embed_dim = hyper_par["embed_dim"]
        self.num_attn_heads = hyper_par["num_attn_heads"]
        if self.embed_dim % self.num_attn_heads != 0:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by "
                f"num_attn_heads ({self.num_attn_heads})"
            )

        self.attn_head_size = self.embed_dim // self.num_attn_heads

        self.qkv_bias = hyper_par["qkv_bias"]

        # Fused projection for query, key and value.
        self.qkv_projection = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=self.qkv_bias)

        self.attn_dropout = nn.Dropout(hyper_par["attn_probs_dropout_prob"])

        # Projects the concatenated heads back to embed_dim.
        self.output_projection = nn.Linear(self.embed_dim, self.embed_dim)
        self.output_dropout = nn.Dropout(hyper_par["hidden_dropout_prob"])

    def _split_heads(self, t):
        """(batch, seq, embed_dim) -> (batch, heads, seq, head_size)."""
        batch_size, sequence_length, _ = t.size()
        return t.view(batch_size, sequence_length, self.num_attn_heads, self.attn_head_size).transpose(1, 2)

    def forward(self, x, output_attns=False):
        """Apply self-attention to ``x`` of shape (batch, seq, embed_dim).

        Returns:
            ``(attn_output, attn_probs)``. ``attn_output`` has the shape of
            ``x``; ``attn_probs`` is the (batch, heads, seq, seq) attention
            matrix after dropout, or ``None`` unless ``output_attns`` is true.
        """
        batch_size, sequence_length, _ = x.size()

        # (batch, seq, embed_dim) -> three (batch, heads, seq, head_size) tensors.
        qkv = self.qkv_projection(x)
        query, key, value = (self._split_heads(part) for part in torch.chunk(qkv, 3, dim=-1))

        # Scaled dot-product scores, normalised over the key axis.
        attn_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.attn_head_size)
        attn_probs = self.attn_dropout(nn.functional.softmax(attn_scores, dim=-1))

        attn_output = torch.matmul(attn_probs, value)

        # (batch, heads, seq, head_size) -> (batch, seq, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, sequence_length, self.embed_dim)

        attn_output = self.output_dropout(self.output_projection(attn_output))

        if not output_attns:
            return (attn_output, None)
        return (attn_output, attn_probs)

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Both linear layers map ``embed_dim`` to ``embed_dim``; the dropout rate is
    ``hyper_par["hidden_dropout_prob"]``.
    """

    def __init__(self, hyper_par):
        super().__init__()
        width = hyper_par["embed_dim"]
        self.lin_1 = nn.Linear(width, width)
        self.activation = nn.GELU()
        self.lin_2 = nn.Linear(width, width)
        self.dropout = nn.Dropout(hyper_par["hidden_dropout_prob"])

    def forward(self, x):
        """Apply the feed-forward network to the last axis of ``x``."""
        hidden = self.activation(self.lin_1(x))
        return self.dropout(self.lin_2(hidden))

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    LayerNorm is applied to the input of each sub-layer and the sub-layer
    output is added back onto the un-normalised stream.
    """

    def __init__(self, hyper_par):
        super().__init__()
        width = hyper_par["embed_dim"]
        self.attn = MultiHeadAttn(hyper_par)
        self.layernorm_1 = nn.LayerNorm(width)
        self.mlp = MLP(hyper_par)
        self.layernorm_2 = nn.LayerNorm(width)

    def forward(self, x, output_attns=False):
        """Return ``(hidden, attn_probs)``; ``attn_probs`` is ``None`` unless
        ``output_attns`` is true."""
        # Attention sub-layer with residual connection.
        attn_output, attn_probs = self.attn(self.layernorm_1(x), output_attns=output_attns)
        hidden = x + attn_output

        # Feed-forward sub-layer with residual connection.
        hidden = hidden + self.mlp(self.layernorm_2(hidden))

        return (hidden, attn_probs if output_attns else None)


class Encoder(nn.Module):
    """Stack of ``hyper_par["num_hidden_layers"]`` transformer blocks."""

    def __init__(self, hyper_par):
        super().__init__()
        depth = hyper_par["num_hidden_layers"]
        self.blocks = nn.ModuleList([TransformerBlock(hyper_par) for _ in range(depth)])

    def forward(self, x, output_attns=False):
        """Run ``x`` through every block in order.

        Returns:
            ``(x, attns)`` where ``attns`` is the list of per-block attention
            probabilities when ``output_attns`` is true, otherwise ``None``.
        """
        collected = []
        for layer in self.blocks:
            x, layer_probs = layer(x, output_attns=output_attns)
            if output_attns:
                collected.append(layer_probs)

        return (x, collected) if output_attns else (x, None)

class ViT(nn.Module):
    """Vision transformer classifier.

    Pipeline: ``Embeddings`` -> ``Encoder`` -> linear classifier applied to
    the token at sequence position 0 (the [CLS] token).
    """

    def __init__(self, hyper_par):
        super().__init__()
        self.hyper_par = hyper_par
        self.image_size = hyper_par["image_size"]
        self.embed_dim = hyper_par["embed_dim"]
        self.num_classes = hyper_par["num_classes"]

        self.embedding = Embeddings(hyper_par)
        self.encoder = Encoder(hyper_par)
        # Maps the [CLS] features to one logit per class.
        self.classifier = nn.Linear(self.embed_dim, self.num_classes)

    def forward(self, x, output_attns=False):
        """Return ``(logits, attns)``; ``attns`` is ``None`` unless
        ``output_attns`` is true, in which case it is whatever the encoder
        returned as its attention list."""
        tokens = self.embedding(x)
        encoded, all_attns = self.encoder(tokens, output_attns=output_attns)

        # Only the first token is used as the classification feature.
        cls_features = encoded[:, 0, :]
        logits = self.classifier(cls_features)

        return (logits, all_attns if output_attns else None)

In [None]:
#Functions to save the model and visualize the results

def save_experiment(experiment_name, hyper_par, model, train_losses, test_losses, accuracies, base_dir="experiments"):
    """Persist an experiment under ``base_dir/experiment_name``.

    Writes ``hyper_par.json`` and ``metrics.json`` (keys sorted, 4-space
    indent) and then stores the model weights as the "final" checkpoint via
    ``save_checkpoint``. The directory is created if it does not exist.
    """
    outdir = os.path.join(base_dir, experiment_name)
    os.makedirs(outdir, exist_ok=True)

    metrics = {
        'train_losses': train_losses,
        'test_losses': test_losses,
        'accuracies': accuracies,
    }

    # Both JSON files share the same formatting; write them in order.
    for filename, payload in (('hyper_par.json', hyper_par), ('metrics.json', metrics)):
        with open(os.path.join(outdir, filename), 'w') as f:
            json.dump(payload, f, sort_keys=True, indent=4)

    save_checkpoint(experiment_name, model, "final", base_dir=base_dir)

def save_checkpoint(experiment_name, model, epoch, base_dir="experiments"):
    """Save ``model.state_dict()`` to ``base_dir/experiment_name/model_<epoch>.pt``.

    ``epoch`` is only interpolated into the file name, so it may be a number
    or a label such as ``"final"``. The directory is created if missing.
    """
    target_dir = os.path.join(base_dir, experiment_name)
    os.makedirs(target_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(target_dir, f'model_{epoch}.pt'))

def load_experiment(experiment_name, checkpoint_name="model_final.pt", base_dir="experiments"):
    """Load an experiment previously written by ``save_experiment``.

    Args:
        experiment_name: Sub-directory of ``base_dir`` holding the experiment.
        checkpoint_name: Weights file to restore (default: the final model).
        base_dir: Root directory of all experiments.

    Returns:
        ``(hyper_par, model, train_losses, test_losses, accuracies)`` where
        ``model`` is a ``ViT`` built from the stored hyper-parameters with the
        checkpoint weights loaded. The model is returned on the CPU.
    """
    outdir = os.path.join(base_dir, experiment_name)

    # Hyper-parameters used to build the model.
    with open(os.path.join(outdir, 'hyper_par.json'), 'r') as f:
        hyper_par = json.load(f)

    # Training curves.
    with open(os.path.join(outdir, 'metrics.json'), 'r') as f:
        metrics = json.load(f)
    train_losses = metrics['train_losses']
    test_losses = metrics['test_losses']
    accuracies = metrics['accuracies']

    # Rebuild the model and restore its weights. map_location="cpu" lets a
    # checkpoint saved from a CUDA model load on a machine without a GPU;
    # callers move the model to their device afterwards.
    model = ViT(hyper_par)
    cpfile = os.path.join(outdir, checkpoint_name)
    model.load_state_dict(torch.load(cpfile, map_location="cpu"))
    return hyper_par, model, train_losses, test_losses, accuracies



# Side length (in pixels) of the square input images.
shape=40
@torch.no_grad()
def visualize_attn(model, output=None):
    """Plot [CLS]-token attention maps for random test samples.

    Draws up to 30 samples from the module-level ``test_dataset`` (the global
    torch RNG is re-seeded with a fixed value so the selection is repeatable),
    runs the model and shows each image twice side by side: plain on the left
    and with the attention map overlaid on the right. Titles are green when
    the prediction matches the label and red otherwise.

    Args:
        model: A model whose forward accepts ``output_attns=True`` and returns
            ``(logits, list_of_attention_tensors)``. It is moved to the GPU
            when one is available and switched to eval mode.
        output: Optional path; when given the figure is also saved there.
    """
    seed = 32
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    model.eval()
    num_images = 30
    testset = test_dataset
    classes = ('Disordered Phase', 'Ordered Phase')

    # Plain Python ints index the dataset regardless of its container type.
    indices = torch.randperm(len(testset))[:num_images].tolist()
    samples = [testset[i] for i in indices]
    labels = [int(label) for _, label in samples]

    # Stack the stored tensors directly. The previous code converted each
    # (1, H, W) sample to numpy and passed it through transforms.ToTensor,
    # which interprets ndarrays as H x W x C; combined with the reshape that
    # followed, the model received a transposed image, so the attention
    # overlay did not line up with the displayed image.
    images = torch.stack(
        [torch.as_tensor(image, dtype=torch.float32).reshape(1, shape, shape) for image, _ in samples]
    )
    raw_images = [image[0].cpu().numpy() for image in images]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    images = images.to(device)
    model = model.to(device)

    logits, attn_maps = model(images, output_attns=True)
    predictions = torch.argmax(logits, dim=1).tolist()

    # Concatenate the maps of all blocks along the head axis, keep the row of
    # the [CLS] query (dropping its attention to itself) and average over
    # every head of every block.
    attn_maps = torch.cat(attn_maps, dim=1)
    attn_maps = attn_maps[:, :, 0, 1:]
    attn_maps = attn_maps.mean(dim=1)

    # Arrange the per-patch weights on the square patch grid.
    num_patches = attn_maps.size(-1)
    size = int(math.sqrt(num_patches))
    attn_maps = attn_maps.reshape(-1, size, size)

    # Resize the maps to the image resolution.
    attn_maps = attn_maps.unsqueeze(1)
    attn_maps = F.interpolate(attn_maps, size=(shape, shape), mode='bilinear', align_corners=False)
    attn_maps = attn_maps.squeeze(1).cpu().numpy()

    fig = plt.figure(figsize=(40, 20))
    # 1 over the left (plain) half, 0 over the right (overlay) half.
    mask = np.concatenate([np.ones((shape, shape)), np.zeros((shape, shape))], axis=1)
    for i in range(len(indices)):
        ax = fig.add_subplot(6, 5, i+1, xticks=[], yticks=[])
        img = np.concatenate((raw_images[i], raw_images[i]), axis=1)
        ax.imshow(img, cmap='gray')
        # Hide the attention map over the left copy of the image.
        extended_attn_map = np.concatenate((np.zeros((shape, shape)), attn_maps[i]), axis=1)
        extended_attn_map = np.ma.masked_where(mask == 1, extended_attn_map)
        ax.imshow(extended_attn_map, alpha=0.5, cmap='jet')
        # Ground truth on the left, prediction on the right.
        gt = classes[labels[i]]
        pred = classes[predictions[i]]
        ax.set_title(f"{gt}                       attn map:{pred}", color=("green" if gt==pred else "red"))
    if output is not None:
        plt.savefig(output)
    plt.show()
    

def plot_image_patches(image, patch_size, stride,output_file=None):
    """Show an image as a grid of square patches annotated with pixel values.

    Args:
        image: Array-like of shape (height, width, channels). Each pixel is
            labelled with its first-channel value multiplied by 10000 and
            rounded to two decimals.
        patch_size: Side length of each patch in pixels.
        stride: Step between the top-left corners of consecutive patches.
        output_file: Optional path; when given the figure is saved as well.

    Raises:
        ValueError: if no patch of the requested size fits into the image.
    """
    image = np.asarray(image)
    height, width, _ = image.shape

    # Extract the patches in row-major order.
    patches = [
        image[y:y + patch_size, x:x + patch_size]
        for y in range(0, height - patch_size + 1, stride)
        for x in range(0, width - patch_size + 1, stride)
    ]
    num_patches = len(patches)
    if num_patches == 0:
        raise ValueError(
            f"patch_size={patch_size} does not fit into an image of size {height}x{width}"
        )

    # Near-square subplot grid large enough for every patch.
    num_cols = int(np.ceil(np.sqrt(num_patches)))
    num_rows = int(np.ceil(num_patches / num_cols))

    # One colour scale for all patches, computed once.
    vmin, vmax = np.min(image), np.max(image)

    # squeeze=False always yields a 2-D array of axes, so `.flat` also works
    # when the grid is a single subplot.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10), squeeze=False)

    for i, ax in enumerate(axs.flat):
        # Hide the frame on every cell, including unused trailing ones.
        ax.axis('off')
        if i >= num_patches:
            continue

        patch = patches[i]
        ax.imshow(patch, cmap='jet', vmin=vmin, vmax=vmax)

        # Write the scaled pixel value on top of each pixel.
        patch_height, patch_width, _ = patch.shape
        for y in range(patch_height):
            for x in range(patch_width):
                scaled = 10000 * patch[y, x]
                # Index the channel explicitly: float() on a one-element
                # array is deprecated in recent NumPy releases.
                ax.text(x, y, str(np.round(float(scaled[0]), 2)), color='black', fontsize=6,
                        ha='center', va='center')

    plt.tight_layout()

    # Save the image as PNG if the output file path is provided.
    if output_file:
        plt.savefig(output_file, dpi=300, bbox_inches='tight')

    plt.show()



@torch.no_grad()
def patch_attn(model):
    """Plot the [CLS]-token attention map of one random test sample.

    Picks one sample from the module-level ``test_dataset``, runs the model,
    averages the [CLS] attention over all heads of all blocks and passes the
    resulting map to ``plot_image_patches`` (20x20 patches, saved as
    ``patch_attn.png``). The ground-truth and predicted class names are
    printed afterwards.

    Fixes relative to the previous version, which could not run:
      * the model was called with the non-existent keyword
        ``output_attentions`` (the forward parameter is ``output_attns``);
      * ``np.assarray`` was a typo;
      * ``plot_image_patches`` received five positional arguments although it
        accepts four;
      * the class lookup used the global ``labels`` tensor instead of the
        label of the selected sample;
      * the image was routed through ``transforms.ToTensor``, which treats
        ndarrays as H x W x C and, with the reshape that followed, transposed
        the image before it reached the model.
    """
    model.eval()
    testset = test_dataset
    classes = ('Disordered Phase', 'Ordered Phase')

    # One random sample; a plain int index works for any dataset container.
    index = int(torch.randperm(len(testset))[0])
    image, label = testset[index]

    # (channels, H, W) -> (1, 1, H, W), keeping the stored pixel layout.
    image = torch.as_tensor(image, dtype=torch.float32)
    images = image.reshape(1, 1, image.shape[-2], image.shape[-1])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    images = images.to(device)
    model = model.to(device)

    logits, attention_maps = model(images, output_attns=True)
    prediction = int(torch.argmax(logits, dim=1))

    # Concatenate the maps of all blocks along the head axis, keep the row of
    # the [CLS] query (dropping its attention to itself) and average over
    # every head.
    attention_maps = torch.cat(attention_maps, dim=1)
    attention_maps = attention_maps[:, :, 0, 1:]
    attention_maps = attention_maps.mean(dim=1)

    # Arrange the per-patch weights on the square patch grid and add the
    # trailing channel axis expected by plot_image_patches.
    num_patches = attention_maps.size(-1)
    size = int(math.sqrt(num_patches))
    attn_map = attention_maps.reshape(size, size, 1).cpu().numpy()

    plot_image_patches(attn_map, 20, 20, "patch_attn.png")

    gt = classes[int(label)]
    pred = classes[prediction]
    print(f"Ground Truth:{gt}",f"Prediction: {pred}")

In [None]:
class Trainer:
    """Minimal training/evaluation loop for a model returning ``(logits, extra)``."""

    def __init__(self, model, optimizer, loss_fn, exp_name, device):
        self.device = device
        self.exp_name = exp_name
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.model = model.to(device)

    def train(self, trainloader, testloader, epochs, save_model_every_n_epochs=0):
        """Train for ``epochs`` epochs, evaluating after each one.

        Prints one summary line per epoch. When ``save_model_every_n_epochs``
        is positive, an intermediate checkpoint is written every that many
        epochs (except after the last epoch). At the end the experiment is
        saved with ``save_experiment``, using the module-level ``hyper_par``.
        """
        history = {"train": [], "test": [], "acc": []}

        for epoch in range(1, epochs + 1):
            train_loss = self.train_epoch(trainloader)
            accuracy, test_loss = self.evaluate(testloader)

            history["train"].append(train_loss)
            history["test"].append(test_loss)
            history["acc"].append(accuracy)
            print(f"Epoch: {epoch}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")

            checkpoint_due = (
                save_model_every_n_epochs > 0
                and epoch % save_model_every_n_epochs == 0
                and epoch != epochs
            )
            if checkpoint_due:
                print('\tSave checkpoint at epoch', epoch)
                save_checkpoint(self.exp_name, self.model, epoch)

        # Persist hyper-parameters, metrics and the final weights.
        save_experiment(self.exp_name, hyper_par, self.model, history["train"], history["test"], history["acc"])

    def train_epoch(self, trainloader):
        """Run one optimisation pass; return the mean loss per sample."""
        self.model.train()
        running_loss = 0
        for batch in trainloader:
            images, labels = (t.to(self.device) for t in batch)

            self.optimizer.zero_grad()
            logits = self.model(images)[0]
            loss = self.loss_fn(logits, labels)
            loss.backward()
            self.optimizer.step()

            # Weight the batch loss by the batch size so the final division
            # yields a per-sample average.
            running_loss += loss.item() * len(images)
        return running_loss / len(trainloader.dataset)

    @torch.no_grad()
    def evaluate(self, testloader):
        """Return ``(accuracy, mean_loss)`` over ``testloader`` in eval mode."""
        self.model.eval()
        running_loss = 0
        hits = 0
        for batch in testloader:
            images, labels = (t.to(self.device) for t in batch)

            logits, _ = self.model(images)

            running_loss += self.loss_fn(logits, labels).item() * len(images)
            hits += (torch.argmax(logits, dim=1) == labels).sum().item()

        num_samples = len(testloader.dataset)
        accuracy = hits / num_samples
        avg_loss = running_loss / num_samples
        return accuracy, avg_loss


def main(exp_name='my_exp', save_model_every_n_epochs=5):
    """Train a ViT on the prepared loaders and save the experiment.

    Reads ``epochs`` and ``lr`` from the module-level ``hyper_par`` and uses
    the module-level ``train_loader`` / ``test_loader``.

    Args:
        exp_name: Name under which checkpoints and metrics are stored.
        save_model_every_n_epochs: Interval for intermediate checkpoints;
            0 disables them.
    """
    epochs = hyper_par['epochs']
    lr = hyper_par['lr']
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Model, optimizer, loss function and trainer.
    model = ViT(hyper_par)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    trainer = Trainer(model, optimizer, loss_fn, exp_name, device=device)

    trainer.train(train_loader, test_loader, epochs, save_model_every_n_epochs=save_model_every_n_epochs)

if __name__ == "__main__":
    main()

In [None]:
# Reload the trained model and its metrics. load_experiment already prefixes
# the name with base_dir="experiments", which is where training saved
# 'my_exp'; the former "../experiments/my_exp" argument expanded to the
# roundabout path "experiments/../experiments/my_exp".
hyper_par, model, train_losses, test_losses, accuracies = load_experiment("my_exp")


In [None]:
# Overlay the CLS-token attention maps on random test images and save the figure as attn.png.
visualize_attn(model, "attn.png")

In [None]:
# Plot the CLS-token attention map of one random test sample, split into patches.
patch_attn(model)

In [None]:

import matplotlib.pyplot as plt

# Training curves: losses on the left panel, test accuracy on the right.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# Left panel: train and test loss per epoch.
ax1.plot(train_losses, label="Train loss")
ax1.plot(test_losses, label="Test loss")
ax1.set(xlabel="Epoch", ylabel="Loss")
ax1.legend()

# Right panel: accuracy per epoch.
ax2.plot(accuracies)
ax2.set(xlabel="Epoch", ylabel="Accuracy")

plt.savefig("metrics.png")
plt.show()