In [1]:
import torch
from torchvision import transforms
from PIL import Image

# Function to load and preprocess the image
def load_image_as_tensor(image_path, batch_size=1):
    # Open the image using PIL
    image = Image.open(image_path)
    
    # Define the transformation pipeline
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),  # Convert PIL Image to tensor and scale to [0, 1]
        # transforms.Lambda(lambda x: x.unsqueeze(0).repeat(batch_size, 1, 1, 1))  # Add batch dimension and repeat
    ])
    
    # Apply the transformation
    tensor = transform(image)
    
    return tensor



In [2]:
image = load_image_as_tensor("examples/006.jpg")

In [3]:
image.shape

torch.Size([3, 256, 256])

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scale_max = 2

In [5]:
model_path = "models/anysr/anysr_edsr_500.pth"

In [6]:
model_load = torch.load(model_path)['model']
model_args = model_load['args']['encoder_spec']['args']
pretrain_dict = model_load['sd']

In [7]:
image_tensor = image.squeeze(dim=0).to(device)

In [8]:
image_tensor.shape

torch.Size([3, 256, 256])

In [9]:
from anysr_module.edsr_anysr import make_edsr


model = make_edsr(scale=2, **model_args)

In [10]:
from anysr_module.anysr_model import replace


model_dict = model.state_dict()
model_dict = replace(pretrain_dict, model_dict)
model.load_state_dict(model_dict, strict=False)

In [11]:
model = model.to(device)

In [12]:
model.scale = 2
model.scale2 = 2

In [14]:
image_tensor.shape

torch.Size([3, 256, 256])

In [12]:
pred = model(
    image_tensor.unsqueeze(dim=0).cuda(),
)

In [16]:
pred.shape

torch.Size([1, 256, 256, 256])

In [13]:
pred.shape

torch.Size([1, 256, 256, 256])

In [14]:
reshaped_tensor = pred.reshape(1, 256, 256, 3, 9, 9)

RuntimeError: shape '[1, 256, 256, 3, 9, 9]' is invalid for input of size 16777216

In [17]:
pred

tensor([[[[-0.2023, -0.2465, -0.1883,  ...,  0.0145,  0.0291,  0.0131],
          [-0.2058, -0.3171, -0.2691,  ...,  0.0181,  0.0362,  0.0219],
          [-0.1518, -0.3128, -0.2633,  ...,  0.0039,  0.0277,  0.0250],
          ...,
          [-0.0476,  0.0612,  0.0954,  ..., -0.1948, -0.2188, -0.1279],
          [-0.0426,  0.0376,  0.0445,  ..., -0.1631, -0.1892, -0.0890],
          [ 0.0051,  0.0672,  0.0702,  ..., -0.0010, -0.0048,  0.0314]],

         [[-0.1818, -0.2641, -0.2487,  ..., -0.1725, -0.1746, -0.1650],
          [-0.0194, -0.1230, -0.1304,  ..., -0.2200, -0.2187, -0.1621],
          [-0.0632, -0.1623, -0.1519,  ..., -0.2183, -0.2122, -0.1603],
          ...,
          [-0.1712, -0.1563, -0.1632,  ..., -0.1246, -0.1534, -0.1306],
          [-0.2195, -0.2493, -0.2216,  ..., -0.1596, -0.1834, -0.1382],
          [-0.1708, -0.2002, -0.1785,  ..., -0.1240, -0.1355, -0.1496]],

         [[ 0.0384, -0.1395, -0.1434,  ...,  0.0056, -0.0025, -0.0254],
          [ 0.0819, -0.0770, -