Transformer (Multi-head Attention Class)
EC523 Project
Team 2

In [1]:
# Import Libraries

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
from torch.utils.data import DataLoader, TensorDataset


Transformer Class

In [2]:
class Transformer(nn.Module):
  def __init__(self, input_dim, hidden_dim, num_heads, output_dim):
    super(Transformer, self).__init__()
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.num_heads = num_heads
    self.output_dim = output_dim

    self.multihead_attn = nn.MultiheadAttention(embed_dim=self.input_dim, num_heads=self.num_heads, batch_first=True)


  def forward(self,x):
    attn_output, attn_weights = self.multihead_attn(x,x,x)

    return attn_output, attn_weights


Generate Random X and Y

In [10]:
Width, Height = 10, 10
Images = 30

X = np.random.rand(Images, Width, Height).reshape(-1, Width * Height)
Y = np.random.rand(Images, Width, Height).reshape(-1, Width * Height)
print(f'Shape of Flattened Image:', X.shape)

X_tensor = torch.tensor(X, dtype=torch.float32)
Y_tensor = torch.tensor(Y, dtype=torch.float32)

dataset = TensorDataset(X_tensor, Y_tensor)
dataloader = DataLoader(dataset, batch_size=20, shuffle=True)

Shape of Flattened Image: (30, 100)


In [3]:
import os
import numpy as np
from PIL import Image 
import matplotlib.pyplot as plt

#The following code loads in the brightfield for the training images in 2b from the storage on the SCC and stores them in a 3D array.
#Warning! This code only works for images that are 256 by 256 right now. It is hardcoded

#/\/\/\/\/\/\/\\/\/\/
#Training BrightField
#/\/\/\/\/\/\/\\/\/\/

# Define the folder containing the TIFF files
folder_path = '../../../projectnb/ec523kb/projects/teams_Fall_2024/Team_2/bacteria_counting/Data/2b/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/brightfield_dataset/train/patches/brightfield'

# List all TIFF files in the folder
tiff_files = [f for f in os.listdir(folder_path) if f.endswith('.tif') or f.endswith('.tiff')]

# Initialize a list to store images
images = []

# Loop through and load each TIFF file
for file_name in tiff_files:
    file_path = os.path.join(folder_path, file_name)

    # Load the image and convert to a NumPy array
    image = np.array(Image.open(file_path))

    # Ensure the image is 256x256, or resize if necessary
    if image.shape != (256, 256):
        image = np.array(Image.open(file_path).resize((256, 256)))

    # Append the image to the list
    # images.append(image)
    
    #plt.imshow(image)  # Use cmap='gray' for grayscale images
    #plt.axis('off')  # Turn off the axis
    #plt.show()
    
    # Split the image into 4 quarters
    top_left = image[:128, :128]  # Top-left quarter
    top_right = image[:128, 128:]  # Top-right quarter
    bottom_left = image[128:, :128]  # Bottom-left quarter
    bottom_right = image[128:, 128:]  # Bottom-right quarter

    # Stack the quarters to form a (4, 64, 64) array
    quarters_image = np.stack([top_left, top_right, bottom_left, bottom_right])

    # Append the quarters image to the list
    images.append(quarters_image)
    

# Convert the list of images into a 3D NumPy array
image_stack = np.stack(images)

image_stack = image_stack.reshape(-1, 128, 128)

#print(image_stack.shape)  # Should be (n_images, 256, 256)


#/\/\/\/\/\/\/\\/\/\/
#Training masks
#/\/\/\/\/\/\/\\/\/\/

# Define the folder containing the TIFF files
folder_path = '../../../projectnb/ec523kb/projects/teams_Fall_2024/Team_2/bacteria_counting/Data/2b/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/brightfield_dataset/train/patches/masks'

# List all TIFF files in the folder
tiff_files = [f for f in os.listdir(folder_path) if f.endswith('.tif') or f.endswith('.tiff')]

# Initialize a list to store images
images = []

# Loop through and load each TIFF file
for file_name in tiff_files:
    file_path = os.path.join(folder_path, file_name)

    # Load the image and convert to a NumPy array
    image = np.array(Image.open(file_path))

    # Ensure the image is 256x256, or resize if necessary
    if image.shape != (256, 256):
        image = np.array(Image.open(file_path).resize((256, 256)))

    # Append the image to the list
    # images.append(image)
    
    #plt.imshow(image)  # Use cmap='gray' for grayscale images
    #plt.axis('off')  # Turn off the axis
    #plt.show()
    
    # Split the image into 4 quarters
    top_left = image[:128, :128]  # Top-left quarter
    top_right = image[:128, 128:]  # Top-right quarter
    bottom_left = image[128:, :128]  # Bottom-left quarter
    bottom_right = image[128:, 128:]  # Bottom-right quarter

    # Stack the quarters to form a (4, 64, 64) array
    quarters_image = np.stack([top_left, top_right, bottom_left, bottom_right])

    # Append the quarters image to the list
    images.append(quarters_image)

# Convert the list of images into a 3D NumPy array
image_stack_masks = np.stack(images)

image_stack_masks = image_stack_masks.reshape(-1, 128, 128)


#print(image_stack_masks.shape)  # Should be (n_images, 256, 256)

#X = np.random.rand(Images, Width, Height).reshape(-1, Width * Height)
#Y = np.random.rand(Images, Width, Height).reshape(-1, Width * Height)
#print(f'Shape of Flattened Image:', X.shape)

print(image_stack.shape)

image_stack = image_stack.reshape(-1, image_stack.shape[1] * image_stack.shape[2])
image_stack_masks = image_stack_masks.reshape(-1, image_stack_masks.shape[1] * image_stack_masks.shape[2])

print(image_stack.shape)

X_tensor = torch.tensor(image_stack, dtype=torch.float32)
Y_tensor = torch.tensor(image_stack_masks, dtype=torch.float32)

dataset = TensorDataset(X_tensor, Y_tensor)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

(112, 128, 128)
(112, 16384)


Instaniate Class (Parameters)

In [4]:
# Instantiate Model Class
transformer_model = Transformer(input_dim=16384, hidden_dim=10, num_heads=4, output_dim=16384)

# Set Optimizer and Training Loss
loss = nn.MSELoss()
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001)

# Training Parameters
num_epochs = 100

In [5]:
print('hello')

hello


In [6]:
# Print Model Parameters
for name, param in transformer_model.named_parameters():
    print(f"Parameter name: {name}")
    print(f"Shape: {param.shape}")

Parameter name: multihead_attn.in_proj_weight
Shape: torch.Size([49152, 16384])
Parameter name: multihead_attn.in_proj_bias
Shape: torch.Size([49152])
Parameter name: multihead_attn.out_proj.weight
Shape: torch.Size([16384, 16384])
Parameter name: multihead_attn.out_proj.bias
Shape: torch.Size([16384])


Train Model

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
transformer_model.to(device)

cuda


Transformer(
  (multihead_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=16384, out_features=16384, bias=True)
  )
)

In [None]:
# Model Training

for epoch in range(num_epochs):
    transformer_model.train()
    epoch_loss = 0.0

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        inputs = inputs.view(inputs.size(0), -1, 16384)  # Reshape for input_dim=100

        optimizer.zero_grad()
        outputs, _ = transformer_model(inputs)
        loss_value = loss(outputs, targets)
        loss_value.backward()
        optimizer.step()

        epoch_loss += loss_value.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / len(dataloader):.4f}")



  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/100], Loss: 176366627515.1250
Epoch [2/100], Loss: 109546750.0134
Epoch [3/100], Loss: 10873376.6295
Epoch [4/100], Loss: 9834765.7720
Epoch [5/100], Loss: 9904183.6116
Epoch [6/100], Loss: 6998101.3742
Epoch [7/100], Loss: 6934522.3583
Epoch [8/100], Loss: 4873388.0056
Epoch [9/100], Loss: 3698774.4464
Epoch [10/100], Loss: 3393951.6722
Epoch [11/100], Loss: 2993138.4752
Epoch [12/100], Loss: 2666033.9104
Epoch [13/100], Loss: 1910382.5156
Epoch [14/100], Loss: 1786948.7335
Epoch [15/100], Loss: 1492270.9060
Epoch [16/100], Loss: 1280306.1719


Model Testing