Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions lars/nepho/inference.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,57 @@
import asyncio
DEFAULT_CATEGORIES = ["NO PRECIPITATION", "STRATIFORM RAIN", "SNOW", "SCATTERED CONVECTION",
"LINEAR CONVECTION", "SUPERCELLS", "UNKNOWN"]
DEFAULT_CATEGORIES = {"No precipitation": "No echoes greater than 10 dBZ present. A circle of echoes near radar site may be present due to ground clutter.",
"Stratiform rain": "Widespread echoes between 0 and 35 dBZ, not present as a circular pattern around the radar site.",
"Scattered Convection": "Present as isolated to scattered cells with reflectivities between 35-65 dBZ",
"Linear convection": "Cells must be organized into a linear structure with reflectivities between 40-60 dBZ",
"Supercells": "Supercells contain the classic hook echo and bounded weak echo region signatures with reflectivities above 55 dBZ",
"Unknown": "If you cannot confidently classify the radar image into one of the above categories"}

async def label_radar_data(radar_df, model, categories=None):
async def label_radar_data(radar_df, model, categories=None, site="Bankhead National Forest",
verbose=True, vmin=-20, vmax=60):
"""
Label radar data using a given model.

Parameters
----------
radar_df (pd.DataFrame): DataFrame containing radar data to be labeled.
model: Model used for labeling the radar data.
site: str: Radar site identifier.

Returns
-------
pd.DataFrame
DataFrame containing the labeled radar data.
"""
if categories is None:
categories = DEFAULT_CATEGORIES
prompt = "This is an image of weather radar base reflectivity data." \
f" The radar site is the ARM Facility {site} site." \
" Please classify the weather depicted into one of the following categories: " \
f"{', '.join(categories) if categories else ', '.join(DEFAULT_CATEGORIES)}."
f"{', '.join(categories) if categories else ', '.join(categories)}."
prompt += "Each category is defined as follows: "
for category, description in categories.items():
prompt += f"{category}: {description}; "
prompt += f"The reflectivity values range from {vmin} dBZ as indicated by the blue colors to {vmax} dBZ as indicated by the red colors."
radar_df["llm_label"] = ""

for fi in radar_df["file_path"].values:
output = await model.chat(prompt, images=[fi])
print(output)
radar_df.loc[radar_df["file_path"] == fi, "label"] = output
time = radar_df.loc[radar_df["file_path"] == fi, "time"].values[0]
prompt_with_time = prompt + f"Please provide just the category label for the radar image taken at time {time}."


output_model = await model.chat(prompt_with_time, images=[fi])
# Find the category label in the output
output = output_model.strip()
for category in categories.keys():
if category.lower() in output.lower():
output = category
break
if verbose:
print("Category assigned:", output)
print("Model output:", output_model)
if output[-1] == ".":
output = output[:-1]
radar_df.loc[radar_df["file_path"] == fi, "llm_label"] = output.strip()


return radar_df
22 changes: 15 additions & 7 deletions lars/nepho/models/ollama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,22 @@ async def chat(self, prompt: str, images: Optional[List[str]] = None) -> str:

image_data = self.encode_image(image_path)
images_data.append(image_data)

#if self.model_name == "llama4:scout":
# payload = {
# "model": self.model_name,
# "messages": [
# {"role": "user", "content": prompt, "images": images_data}
# ],
# "stream": False
#}
#else:
payload = {
"model": self.model_name,
"prompt": prompt,
"images": images_data,
"stream": False
"model": self.model_name,
"prompt": prompt,
"images": images_data,
"stream": False
}

# Use generate endpoint for vision models
url = self.api_url
else:
Expand Down Expand Up @@ -107,7 +115,7 @@ async def chat(self, prompt: str, images: Optional[List[str]] = None) -> str:

def supports_vision(self) -> bool:
"""Check if this model supports vision capabilities."""
vision_models = ["llava", "bakllava", "moondream", "minicpm-v", "llava-llama2", "llava-llama3"]
vision_models = ["llava", "bakllava", "moondream", "minicpm-v", "llava-llama2", "llava-llama3", "llama4:scout"]
return any(vision_model in self.model_name.lower() for vision_model in vision_models)

async def list_available_models(self) -> List[str]:
Expand Down
28 changes: 17 additions & 11 deletions lars/preprocessing/radar_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def preprocess_radar_data(file_path, output_path,
"""

file_list = glob.glob(file_path + '/*.nc')
out_df = pd.DataFrame(columns=['file_path', 'time', 'label'])
out_df = pd.DataFrame(columns=['file_path', 'time', 'label', 'ref_min', 'ref_max'])
if not "vmin" in kwargs:
kwargs['vmin'] = -20
if not "vmax" in kwargs:
Expand All @@ -57,26 +57,32 @@ def preprocess_radar_data(file_path, output_path,
if sweep["sweep_mode"] == 'ppi' or sweep["sweep_mode"] == 'sector':
fig = plt.figure(figsize=(4, 4))
ax = plt.axes()
sweep["corrected_reflectivity"].plot(x="x", y="y",
add_colorbar=False,
ax=ax,
**kwargs)
sweep["corrected_reflectivity"].where(
sweep["corrected_reflectivity"] > min_ref).plot(x="x", y="y",
ax=ax,
**kwargs)
min_ref = sweep["corrected_reflectivity"].where(
sweep["corrected_reflectivity"] > min_ref).values.min()
max_ref = sweep["corrected_reflectivity"].where(
sweep["corrected_reflectivity"] > min_ref).values.max()

ax.set_xlim(x_bounds)
ax.set_ylim(y_bounds)
ax.axis('off')
ax.set_title('')
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_xlabel('X [m]')
ax.set_ylabel('Y [m]')
ax.set_xticks([-100000, -50000, 0, 50000, 100000])
ax.set_yticks([-100000, -50000, 0, 50000, 100000])
fig.tight_layout()
file_name = os.path.join(output_path,
os.path.basename(file).replace('.nc', '.png'))
time_str = pd.to_datetime(sweep["time"].values[0]).strftime('%Y-%m-%d %H:%M:%S')
label = "UNKNOWN" # Placeholder for actual label extraction logic
fig.savefig(os.path.join(output_path,
os.path.basename(file).replace('.nc', '.png')),
dpi=100)
dpi=150)
plt.close(fig)
out_df.loc[len(out_df)] = [file_name, time_str, label]
out_df.loc[len(out_df)] = [file_name, time_str, label, min_ref, max_ref]

else:
print(f"Sweep mode is not PPI or sector scan in {file}, skipping.")
else:
Expand Down
Loading