In [21]:
import torch
import torchvision
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

# Load the image
data = "./components/download.jpg"
image_tensor = torchvision.io.read_image(data)

# Normalize the tensor to have values between 0 and 1 if it's in uint8
if image_tensor.dtype == torch.uint8:
    image_tensor = image_tensor.float() / 255

# Define the transformations
# Note: Ensure transforms.ToPILImage() is only used if necessary for compatibility with Resize
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize the image to 224x224
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Apply the transformation
# If the image tensor is not a PIL Image, convert it before applying Resize
if not isinstance(image_tensor, torch.FloatTensor):  # Check if it's already a FloatTensor
    image_tensor = transforms.ToPILImage()(image_tensor)
image_tensor = transform(image_tensor)

# Check the shape
print(image_tensor.shape)  # Should be [3, 224, 224]

# Add a batch dimension if necessary
# if image_tensor.dim() == 3:
#     image_tensor = image_tensor.unsqueeze(0)  # Should now be [1, 3, 224, 224]

# Check the final shape
print(image_tensor.shape)  # Should be [1, 3, 224, 224]


torch.Size([3, 224, 224])
torch.Size([1, 3, 224, 224])


In [28]:
print(image_tensor.shape)
class SimpleDataset(Dataset):
    def __init__(self, image):
        self.image = image

    def __len__(self):
        return 1  # We have only one image

    def __getitem__(self, idx):
        return self.image

dataset = SimpleDataset(image_tensor)
dataloader = DataLoader(dataset)

# Define the model
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 2)  # Assuming 2 classes

# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print(image_tensor.shape)

# Run the model on the image (for inference)
model.eval()  # Set the model to evaluation mode
# with torch.no_grad():

#     outputs = model(image_tensor)
#     print(outputs)
# print(image_tensor.shape)

# Train the model on the image
num_epochs = 5  # Number of epochs
model.train()  # Set the model to training mode
for epoch in range(num_epochs):
    for inputs in dataloader:
        print(inputs.shape)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, torch.tensor([1]).long())  # assuming the class label is 1
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')


torch.Size([1, 3, 224, 224])




torch.Size([1, 3, 224, 224])
tensor([[ 0.7066, -1.0179]])
torch.Size([1, 3, 224, 224])
torch.Size([1, 1, 3, 224, 224])


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 1, 3, 224, 224]