In [1]:
# Importing dependencies
import torch
from PIL import Image
from torch import nn,save,load
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [3]:
# Loading Data
#transform = transforms.Compose([transforms.ToTensor()])
#train_dataset = datasets.MNIST(root="data", download=True, train=True, transform=transform)
#train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # MRT ist meist in Graustufen
    transforms.Resize((128, 128)),  # Größe nach Bedarf anpassen
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
train_dataset = datasets.ImageFolder(root="archive/Training", transform=transform)      #pituitary aus Datensatz rausgelöscht
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


In [4]:
# Define the image classifier model
class ImageClassifier(nn.Module):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, kernel_size=3), 
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        dummy_input = torch.randn(1, 1, 128, 128)
        dummy_output = self.conv_layers(dummy_input)
        flatten_size = dummy_output.view(1, -1).shape[1]

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flatten_size, 10)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

In [8]:
# Create an instance of the image classifier model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
classifier = ImageClassifier().to(device)

cpu


In [6]:
# Define the optimizer and loss function
optimizer = Adam(classifier.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Train the model
for epoch in range(10):  # Train for 10 epochs-> 100
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()  # Reset gradients ----------------------- ?
        outputs = classifier(images)  # Forward pass
        loss = loss_fn(outputs, labels)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
    print(f"Epoch:{epoch} loss is {loss.item()}")


Epoch:0 loss is 2.311532497406006


KeyboardInterrupt: 

In [None]:
# Save the trained model
torch.save(classifier.state_dict(), 'model_state.pt')

In [None]:
# Load the saved model
with open('model_state.pt', 'rb') as f: 
     classifier.load_state_dict(load(f))  
       

In [10]:
# Perform inference on an image
#img = Image.open('image.jpg')
#img_transform = transforms.Compose([transforms.ToTensor()])
#img_tensor = img_transform(img).unsqueeze(0).to(device)
#output = classifier(img_tensor)
#predicted_label = torch.argmax(output)
#print(f"Predicted label: {predicted_label}")

img = Image.open('/nfs/homes/sdreyer/Digit-Classification-Pytorch/archive/Testing/glioma/Te-gl_0010.jpg')
img_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # MRT ist meist in Graustufen
    transforms.Resize((128, 128)),  # Größe nach Bedarf anpassen
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

img_tensor = img_transform(img).unsqueeze(0).to(device)
#output = classifier(img_tensor)
#predicted_label = torch.argmax(output)
#class_names = train_dataset.classes
#
#
#print(f"Predicted label: {predicted_label}")
#print(f"Predicted class: {class_names}")

with torch.no_grad():
    output = classifier(img_tensor)
    predicted_index = torch.argmax(output).item()

# Klassennamen laden
class_names = train_dataset.classes  
predicted_label = class_names[predicted_index]
print(f"Predicted class: {predicted_label}")

Predicted class: meningioma
