In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from torchsummary import summary
from PIL import Image
import os
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
class HTSDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.tif')]
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert('L')  # Convert to grayscale

        # Load corresponding label
        label_name = os.path.splitext(img_name)[0] + '.txt'
        label_path = os.path.join(self.label_dir, label_name)
        with open(label_path, 'r') as file:
            line = file.readline().strip().split('\t')
            critical_current = float(line[0])
            thickness = float(line[1])

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor([critical_current, thickness], dtype=torch.float32)

# Example of how to use the dataset
transform = transforms.Compose([transforms.ToTensor()])
dataset = HTSDataset(image_dir='ProcessedImages/10kx', label_dir='Ic', transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)


In [3]:
class HTSModel(nn.Module):
    def __init__(self):
        super(HTSModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected layers for image features
        self.fc1 = nn.Linear(64 * 64 * 64, 128)  # Image features to 128 dimensions

        # Fully connected layers for thickness
        self.fc_thickness = nn.Linear(1, 16)  # Simple layer for thickness

        # Combined image features and thickness
        self.fc2 = nn.Linear(128 + 16, 64)  # Combine image features and thickness
        self.fc3 = nn.Linear(64, 1)  # Output critical current

    def forward(self, x, thickness):
        # CNN part: Feature extraction from the image
        x = self.pool(torch.relu(self.conv1(x)))  # Conv1 + Pooling
        x = self.pool(torch.relu(self.conv2(x)))  # Conv2 + Pooling
        
        # Flatten the CNN output
        x = x.view(x.size(0), -1)  # Flatten image features

        # Pass image features through first fully connected layer
        x = torch.relu(self.fc1(x))

        # Process thickness through a simple FC layer
        thickness = torch.relu(self.fc_thickness(thickness))  # Thickness to 16 dimensions

        # Concatenate image features with processed thickness
        combined = torch.cat((x, thickness), dim=1)

        # Pass combined features through remaining fully connected layers
        x = torch.relu(self.fc2(combined))
        x = self.fc3(x)  # Output critical current
        return x


# Instantiate the model
model = HTSModel()

In [4]:
# Create an instance of the model
model = HTSModel()

# Print summary (assuming input image size of 64x64)
summary(model, [(1, 256, 256), (1, 1)])  # Image input size and thickness input size

ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH

In [9]:
def train_model(model, dataloader, criterion, optimizer, num_epochs):
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            images, thickness = inputs, labels[:, 1].unsqueeze(1)
            critical_current = labels[:, 0].unsqueeze(1)
            
            optimizer.zero_grad()
            outputs = model(images, thickness)
            loss = criterion(outputs, critical_current)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(dataloader)}')

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
train_model(model, dataloader, criterion, optimizer, num_epochs=57)


Epoch 1/57, Loss: 15602.43664426389
Epoch 2/57, Loss: 2387.0061094864554
Epoch 3/57, Loss: 248.50210057134214
Epoch 4/57, Loss: 103.83615431578264
Epoch 5/57, Loss: 61.68021293308424
Epoch 6/57, Loss: 51.24359470865001
Epoch 7/57, Loss: 46.48269114287003
Epoch 8/57, Loss: 33.00340810029403
Epoch 9/57, Loss: 22.194990821506668
Epoch 10/57, Loss: 21.724803468455438
Epoch 11/57, Loss: 16.25571170060531
Epoch 12/57, Loss: 15.147523797076682
Epoch 13/57, Loss: 10.722421116155127
Epoch 14/57, Loss: 6.382235444110373
Epoch 15/57, Loss: 5.327297563138216
Epoch 16/57, Loss: 5.209540325662364
Epoch 17/57, Loss: 9.743270874023438
Epoch 18/57, Loss: 8.909779815570168
Epoch 19/57, Loss: 3.896741053332453
Epoch 20/57, Loss: 3.363451651904894
Epoch 21/57, Loss: 2.8419046816618545
Epoch 22/57, Loss: 3.4648740887641907
Epoch 23/57, Loss: 3.108551605888035
Epoch 24/57, Loss: 1.6863784556803496
Epoch 25/57, Loss: 1.5425043352272199
Epoch 26/57, Loss: 1.9860107639561528
Epoch 27/57, Loss: 1.93735610920449

In [15]:
from PIL import Image

def load_and_preprocess_image(image_path, transform=None):
    image = Image.open(image_path).convert('L')  # Convert to grayscale
    if transform:
        image = transform(image)
    image = image.unsqueeze(0)  # Add batch dimension
    return image

# Example usage
image_path = 'Test/test_1.tif'  # Replace with the path to your image
image = load_and_preprocess_image(image_path, transform=transform)

thickness_value = 3.6  # Replace with the actual thickness value
thickness_tensor = torch.tensor([[thickness_value]], dtype=torch.float32)

model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    predicted_current = model(image, thickness_tensor)
    
print(f"Predicted Critical Current: {int(predicted_current.item())}")


Predicted Critical Current: 599


In [14]:
torch.save(model.state_dict(), "10kx_weights.pth")