# Classification of Simulated Lenses in DESI using Visual Transformer

Author: Anthony LaBarca, modified by Delaney Cummins

Date: 2023-07-24

Version: 1.0

License: ---

Description: This script will adapt the ViT model to DESI spectral data of resolution 3600

In [1]:
import json
import os
import sys
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchsummary

from torch.utils.data import DataLoader, Dataset

def is_interactive():
    import __main__ as main
    return not hasattr(main, "__file__")


if is_interactive():
    from tqdm.notebook import tqdm, trange
else:
    from tqdm import tqdm, trange

# Imports
torch.manual_seed(0)
np.random.seed(0)

# Synthetic Data Location 
#filepath to pickles
# filepath='/global/homes/a/alabarca/DESI-Timedomain/simulated_preprocessing/old/'


## ViT Code (Don't want to deal with any imports atm)

In [2]:
'''
EXAMPLE OUTPUT OF TRANFORMER WITH 100 PATCHES
torch.Size([1, 100, 36])                                                                 # Original 
torch.Size([1, 100, 25])                                                                 # Move to hidden Dimension
torch.Size([1, 101, 25])                                                                 # Add Classification Token
torch.Size([1, 101, 25])                                                                 # Add positional embeddings
torch.Size([1, 101, 25])                                                                 # 1 ViT Block
torch.Size([1, 101, 25])                                                                 # 2 ViT Block
torch.Size([1, 101, 25])                                                                 # 3 Vit Block
torch.Size([1, 101, 25])                                                                 # 4 ViT Block
torch.Size([1, 101, 25])                                                                 # 5 ViT Block
torch.Size([1, 25])                                                                      # Take only the classification Token
tensor([[0.4370, 0.1115, 0.0393, 0.0505, 0.0449, 0.3168]], grad_fn=<SoftmaxBackward>)    # Final Output (Predictions)
'''
''

''

## Training

In [3]:
print("Testing ViT Training")
print("-"* 20)
batch_size=256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"DEVICE IN USE {device}")

num_epochs= 60
learning_rate=0.001
model_suff='lens-3600'
plot=False
retrain = True
continue_train = True
previous_epoch = 0

 # model_name
model_name = f"vit_model_{model_suff}"
print("Model name: ", model_name)
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

Testing ViT Training
--------------------
DEVICE IN USE cpu
Model name:  vit_model_V1_big


### Import Data and turn into dataloaders

In [4]:
# Load parameters and data
with open(f"split/parameters.json", "r") as f:
    model_params = json.load(f)

patch_size = model_params['patch_size']
patch_num = model_params['patch_num']
spectra_length = model_params['spectra_length']

x_train = torch.Tensor(np.load('split/V1_xtrain.npy'))
y_train = torch.Tensor(np.load('split/V1_ytrain.npy')).long().squeeze()
x_test = torch.Tensor(np.load('split/V1_xtest.npy'))
y_test = torch.Tensor(np.load('split/V1_ytest.npy')).long().squeeze()

print("Training set: ", x_train.shape)
print("Test set: ", x_test.shape)

Training set:  torch.Size([15514, 1, 7800])
Test set:  torch.Size([5172, 1, 7800])


In [5]:
# Create dataloaders
def get_dataloaders(x_train, x_test, y_train, y_test, batch_size, device):
    class SpectraDataset(Dataset):
        def __init__(self, x, y):
            self.x = x
            self.y = y
        
        def __len__(self):
            return len(self.x)
        
        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]
    
    train_dataset = SpectraDataset(x_train, y_train)
    test_dataset = SpectraDataset(x_test, y_test)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=device)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=device)
    
    return train_loader, test_loader



train_loader, test_loader = get_dataloaders(x_train, x_test, y_train, y_test, batch_size, device)

### Create Model 

In [6]:
# Helper Methods
def patchify(spectra: torch.Tensor, n_patches: int) -> torch.Tensor:
    """
    spectra: 1D spectra: torch.Tensor of shape (N, 1, len_spectrum)
    n_patches: number of patches to break the spectra into (must be a factor of len_spectrum)

    return: patches of the spectra: torch.Tensor of shape (N, n_patches, len_spectrum // n_patches)
    """

    n, _, l_spectrum = spectra.shape

    # create patches
    patch_size = l_spectrum // n_patches
    patches = torch.zeros(n, n_patches, l_spectrum // n_patches)
    for idx, spectrum in enumerate(spectra):
        for i in range(n_patches):
            patch = spectrum[:, i * patch_size: (i + 1) * patch_size]
            patches[idx, i] = patch

    return patches

def positional_embedding(i, j, d):
    """
    i: tensor index
    j: embedding dimension

    return: positional embedding for i, j
    """

    if j % 2 == 0:
        return np.sin(i / (10000 ** (j / d)))
    return np.cos(i / (10000 ** ((j - 1) / d)))


def get_positional_embeddings(sequence_length: int, d) -> torch.Tensor:
    """
    sequence_length: length of sequence
    d: embedding dimension

    return: positional embeddings for sequence of length sequence_length
    """

    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = positional_embedding(i, j, d)

    return result
    
class MSA(nn.Module):
    """
    Multi-Head Self-Attention
    """
    
    def __init__(self, d, n_heads=2):
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads
        
        assert d % n_heads == 0, f"Cannot divide dimension {d} into {n_heads} heads"
        
        d_head = int(d / n_heads)
        self.q_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, sequences):
        """
        Sequences have shapes (N, seq_length, token_dim)
        We must transform to shape (N, seq_length, n_heads, token_dim / n_heads)
        and concatenate back into (N, seq_length, token_dim)
        """
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]
                
                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
                
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])


class ViTBlock(nn.Module):
    """
    Transformer Encoder Block
    """
    
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(ViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(nn.Linear(hidden_d, mlp_ratio * hidden_d), nn.GELU(),
                                 nn.Linear(mlp_ratio * hidden_d, hidden_d))
    
    def forward(self, x):
        """
        Encoder1 will normalize input, pass through MSA,
        add residual connection

        Encoder2 will normalize encoder1, pass through MLP
        """
        encoder1 = x + self.mhsa(self.norm1(x))
        encoder2 = encoder1 + self.mlp(self.norm2(encoder1))
        return encoder2


class ViT(nn.Module):
    def __init__(self, cl=(1, 1024), n_patches=64, n_blocks=2, hidden_d=8, n_heads=2, out_d=10, device = None):
        super(ViT, self).__init__()
        
        # If device is provided, use that. 
        # HOwever, if CUDA is availible, that is the default device
        if device is not None:
            self.device = device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"ViT IS NOW IN {self.device}")
        
        self.cl = cl  # (channels, length)
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d
        
        assert cl[1] % n_patches == 0, "Image length must be divisible by n_patches"
        self.patch_size = cl[1] // n_patches
        
        # Linear mapping of patches to hidden dimension
        self.input_d = int(cl[0] * self.patch_size)
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d).to(self.device)
        
        # Classification Token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d).to(self.device))
        
        # Positional embeddings
        self.pos_embed = nn.Parameter(get_positional_embeddings(
            n_patches + 1, self.hidden_d).clone().detach())
        self.pos_embed.requires_grad = False
        
        # Transformer Encoder
        self.blocks = nn.ModuleList(
            [ViTBlock(hidden_d, n_heads) for _ in range(n_blocks)]).to(self.device)
        
        # Classification mlp
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d), nn.Softmax(dim=-1)).to(self.device)
        
        
        self.apply(self._init_weights)
    
    
    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=1.0)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    
    def forward(self, images):
        # Creating patches
        n, _, _ = images.shape
        patches = patchify(images, self.n_patches).to(self.device)
        # Linear tokenization --> map vector corresponding to each patch to hidden dimension
        image_tokens = self.linear_mapper(patches)
        # Adding classification
        tokens = torch.stack([torch.vstack(
            (self.class_token, image_tokens[i])) for i in range(len(image_tokens))])
        # Adding positional embeddings
        pos_embed = self.pos_embed.repeat(n, 1, 1)
        out = tokens + pos_embed
        
        for block in self.blocks:
            out = block(out)
        
        # For classification, we take the first token
        out = out[:, 0]
        
        return self.mlp(out)
    
    def saveparams(self, model_name):
        dict = {'cl': self.cl, 'patches':self.n_patches, 'n_blocks':self.n_blocks, 'n_heads':self.n_heads, 'hidden_d':self.hidden_d}
        with open(f'{model_name}_parameters.json', 'w') as f:
            json.dump(dict, f)

# Create Old Model 
# model = ViT(cl=(1, spectra_length), n_patches=patch_num, n_blocks=4, hidden_d = patch_num // 4, n_heads = 5, out_d = 6).to(device)

In [7]:
# Create New Model
model = ViT(cl=(1, spectra_length), n_patches=patch_num, n_blocks=4, hidden_d = patch_size * 2, n_heads = 12, out_d = 2).to(device)

model.saveparams(model_name)



# Optimizer and Loss Functions
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Initialize History Tracking
model_history = {}
# if continue_train and previous_epoch > 0:
#     with open(f'{model_name}_history.json', "r") as f:
#         model_history = json.load(f)
#         print("Loaded")
#     model.load_state_dict(torch.load(f'model/{model_name}_epoch{previous_epoch}.pt'))
#     print("Loaded")

#torchsummary.summary(model, (1, spectra_length))

ViT IS NOW IN cpu


In [8]:
def train_model(train_loader, model, device, epoch, criterion, optimizer):
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
        x, y = batch
        x = x.type(torch.LongTensor) 
        y = y.type(torch.LongTensor) 
        x, y = x.to(device), y.to(device)
        
        y_hat = model(x)
        loss = criterion(y_hat, y)
        
        train_loss += loss.detach().cpu().item() / len(train_loader)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return train_loss

def test_model(test_loader: DataLoader, model: nn.Module, device: torch.device, criterion, plot: bool = True):
    """
    Test the model on the test set 

    Parameters: 
    -----------
    test_loader: DataLoader 
    """
    model.train().to(device)
    
    correct, total = 0, 0
    test_loss = 0.0
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = batch
        x = x.type(torch.LongTensor) 
        y = y.type(torch.LongTensor) 
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        # force printout
        loss = criterion(y_hat, y)
        test_loss += loss.detach().cpu().item() / len(test_loader)
        
        correct += torch.sum(torch.argmax(y_hat, dim=1)
                             == y).detach().cpu().item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

    # If you don't want Predictions, just return and get out 
    # if plot or y_hat is None:
    return correct / total * 100, test_loss
    
#     # If you want a figure, make some predictions
#     # get first batch from dataset, and plot the first 10 predictions made by the model
#     data = next(iter(test_loader))
#     x, y = data
#     x = x.type(torch.LongTensor) 
#     y = y.type(torch.LongTensor) 
#     x, y = x.to(device), y.to(device)
#     y_hat = model(x)

#     fig, ax = plt.subplots(10, 1, figsize=(10, 20), sharex="all")
#     for i in range(10):
#         print(f"Prediction: {torch.argmax(y_hat[i])}, Label: {y[i]}")
#         ax[i].plot(x[i][0])
#         ax[i].set_title(
#             f"Prediction: {torch.argmax(y_hat[i])}, Label: {y[i]}")
#     fig.supylabel("Intensity")
#     fig.supxlabel("Wavelength")
#     fig.tight_layout()
#     return correct / total * 100, test_loss
    

In [9]:
if not os.path.exists(model_name) or retrain:
    training_loss, test_loss, test_acc = [], [], []
    if continue_train and previous_epoch > 0:
        training_loss = model_history["training_loss"]
        test_loss = model_history["test_loss"]
        test_acc = model_history["test_acc"]
    print(len(training_loss))
    # Training loop
    for epoch in trange(num_epochs, desc="Training"):
        # Train model
        model.train().to(device)

        train_loss = train_model(train_loader, model, device, epoch, criterion, optimizer)
        print(f"Epoch {previous_epoch + epoch + 1}/{previous_epoch + num_epochs} loss: {train_loss:.2f}")

        # Test model
        model.eval()
        with torch.no_grad():
            test_acc_curr, test_loss_curr = test_model(test_loader, model, device, criterion, plot=False)

        # Bookkeeping
        training_loss.append(train_loss)
        test_loss.append(test_loss_curr)
        test_acc.append(test_acc_curr)
        # Save model
        model_history = {"training_loss": training_loss, "test_loss": test_loss, "test_acc": test_acc}
        # Save model and history
        torch.save(model.state_dict(), f"model/{model_name}_epoch{previous_epoch + epoch}.pt")
        
        with open(f"{model_name}_history.json", "w") as f:
            json.dump(model_history, f)

    if plot:
        with torch.no_grad():
            test_model(test_loader, model, device, criterion, plot=plot)

0


Training:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 in training:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 1/100 loss: 0.85


Testing:   0%|          | 0/21 [00:00<?, ?it/s]

Test loss: 0.84
Test accuracy: 46.95%


Epoch 2 in training:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 2/100 loss: 0.85


Testing:   0%|          | 0/21 [00:00<?, ?it/s]

Test loss: 0.84
Test accuracy: 46.95%


Epoch 3 in training:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 3/100 loss: 0.86


Testing:   0%|          | 0/21 [00:00<?, ?it/s]

Test loss: 0.84
Test accuracy: 46.95%


Epoch 4 in training:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 4/100 loss: 0.86


Testing:   0%|          | 0/21 [00:00<?, ?it/s]

Test loss: 0.84
Test accuracy: 46.95%


Epoch 5 in training:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 5/100 loss: 0.85


Testing:   0%|          | 0/21 [00:00<?, ?it/s]

Test loss: 0.84
Test accuracy: 46.95%


Epoch 6 in training:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 6/100 loss: 0.86


Testing:   0%|          | 0/21 [00:00<?, ?it/s]

Test loss: 0.84
Test accuracy: 46.95%


Epoch 7 in training:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 7/100 loss: 0.86


Testing:   0%|          | 0/21 [00:00<?, ?it/s]

Test loss: 0.84
Test accuracy: 46.95%


Epoch 8 in training:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 8/100 loss: 0.86


Testing:   0%|          | 0/21 [00:00<?, ?it/s]

Test loss: 0.84
Test accuracy: 46.95%


Epoch 9 in training:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 9/100 loss: 0.86


Testing:   0%|          | 0/21 [00:00<?, ?it/s]

Test loss: 0.84
Test accuracy: 46.95%


Epoch 10 in training:   0%|          | 0/61 [00:00<?, ?it/s]

KeyboardInterrupt: 

# History

In [None]:
def plot_history(history: dict, figure=None):
    """
    Plot training history

    Parameters
    ----------
    history: dict
        Dictionary with training history
    figure: tuple
        Tuple of (fig, ax) to plot on

    Returns
    -------
    fig, ax
        Figure and axis
    """
    # plot on one axis, with two y axes (one for loss, one for accuracy)
    
    # figure is either None, or a tuple of (fig, ax)
    if figure is None:
        fig, ax1 = plt.subplots(1, 1, figsize=(10, 5))
        assert isinstance(ax1, plt.Axes), "Figure must be a tuple of (fig, ax)"
        ax2 = ax1.twinx()
    else:
        fig, ax = tuple(figure)
        if isinstance(ax, np.ndarray):
            ax1 = ax[0]
            ax2 = ax[1]
        else:
            assert isinstance(
                ax, plt.Axes), "Figure must be a tuple of (fig, ax)"
            ax1 = ax
            ax2 = ax1.twinx()
    
    assert isinstance(ax1, plt.Axes), "Figure must be a tuple of (fig, ax)"
    assert isinstance(ax2, plt.Axes), "Figure must be a tuple of (fig, ax)"
    
    # List of colors to be used
    colors = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple",
              "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan"]
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.plot(history["training_loss"], label="Training loss", color=colors[0])
    ax1.plot(history["test_loss"], label="Test loss", color=colors[1])
    ax1.legend()
    
    ax2.set_ylabel("Accuracy")
    # Use the 3rd color in the color cycle
    ax2.plot(history["test_acc"], label="Test accuracy", color=colors[2])
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy")
    ax2.legend()
    ax = (ax1, ax2)
    fig.tight_layout()
    
    return fig, ax

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
print(type(ax))
history_figure2, hist_ax = plot_history(model_history, figure = (fig, ax))
history_figure2.suptitle("Loss and Accuracy of ViT")
history_figure2.tight_layout()
history_figure2.savefig(f'history-{model_name}.png', facecolor='white', transparent=False, edgecolor='none')

In [None]:
def get_positional_embeddings(sequence_length: int, d) -> torch.Tensor:
    """
    sequence_length: length of sequence
    d: embedding dimension

    return: positional embeddings for sequence of length sequence_length
    """
    
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = positional_embedding(i, j, d)
    
    return result

test = nn.Parameter(get_positional_embeddings(100 + 1, 300).clone().detach())

pos_embed = test.repeat(1, 1, 1)
fig, ax = plt.subplots(1,1, figsize=(4,4))
print(pos_embed[0].detach().numpy().shape)
ax.imshow(pos_embed[0].detach().numpy())
ax.set_title("Positional Embeddings")
ax.set_xlabel("Hidden Dimension")
ax.set_ylabel("Patch #")
fig.tight_layout()
plt.savefig("embeddings.png", dpi=300)