diff --git a/crates/ruff/src/jupyter/notebook.rs b/crates/ruff/src/jupyter/notebook.rs index d6e5885e56776..37d372a65d5f1 100644 --- a/crates/ruff/src/jupyter/notebook.rs +++ b/crates/ruff/src/jupyter/notebook.rs @@ -1,4 +1,5 @@ use std::cmp::Ordering; +use std::fmt::Display; use std::fs::File; use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write}; use std::iter; @@ -41,6 +42,20 @@ pub fn round_trip(path: &Path) -> anyhow::Result { Ok(String::from_utf8(writer)?) } +impl Display for SourceValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SourceValue::String(string) => f.write_str(string), + SourceValue::StringArray(string_array) => { + for string in string_array { + f.write_str(string)?; + } + Ok(()) + } + } + } +} + impl Cell { /// Return the [`SourceValue`] of the cell. fn source(&self) -> &SourceValue { @@ -407,6 +422,11 @@ impl Notebook { self.content = transformed.to_string(); } + /// Return a slice of [`Cell`] in the Jupyter notebook. + pub fn cells(&self) -> &[Cell] { + &self.raw.cells + } + /// Return `true` if the notebook is a Python notebook, `false` otherwise. pub fn is_python_notebook(&self) -> bool { self.raw diff --git a/crates/ruff_cli/src/diagnostics.rs b/crates/ruff_cli/src/diagnostics.rs index 5e494a0482713..47217002981a2 100644 --- a/crates/ruff_cli/src/diagnostics.rs +++ b/crates/ruff_cli/src/diagnostics.rs @@ -17,7 +17,7 @@ use similar::TextDiff; use std::os::unix::fs::PermissionsExt; use crate::cache::Cache; -use ruff::jupyter::Notebook; +use ruff::jupyter::{Cell, Notebook}; use ruff::linter::{lint_fix, lint_only, FixTable, FixerResult, LinterResult}; use ruff::logging::DisplayParseError; use ruff::message::Message; @@ -261,13 +261,64 @@ pub(crate) fn lint_path( } }, flags::FixMode::Diff => { - let mut stdout = io::stdout().lock(); - TextDiff::from_lines(contents.as_str(), &transformed) - .unified_diff() - .header(&fs::relativize_path(path), &fs::relativize_path(path)) - .to_writer(&mut stdout)?; - stdout.write_all(b"\n")?; - stdout.flush()?; + match &source_kind { + SourceKind::Python(_) => { + let mut stdout = io::stdout().lock(); + TextDiff::from_lines(contents.as_str(), &transformed) + .unified_diff() + .header(&fs::relativize_path(path), &fs::relativize_path(path)) + .to_writer(&mut stdout)?; + stdout.write_all(b"\n")?; + stdout.flush()?; + } + SourceKind::Jupyter(dest_notebook) => { + // We need to load the notebook again, since we might've + // mutated it. + let src_notebook = match load_jupyter_notebook(path) { + Ok(notebook) => notebook, + Err(diagnostic) => return Ok(*diagnostic), + }; + let mut stdout = io::stdout().lock(); + for ((idx, src_cell), dest_cell) in src_notebook + .cells() + .iter() + .enumerate() + .zip(dest_notebook.cells().iter()) + { + let (Cell::Code(src_code_cell), Cell::Code(dest_code_cell)) = (src_cell, dest_cell) else { + continue; + }; + TextDiff::from_lines( + &src_code_cell.source.to_string(), + &dest_code_cell.source.to_string(), + ) + .unified_diff() + // Jupyter notebook cells don't necessarily have a newline + // at the end. For example, + // + // ```python + // print("hello") + // ``` + // + // For a cell containing the above code, there'll only be one line, + // and it won't have a newline at the end. If it did, there'd be + // two lines, and the second line would be empty: + // + // ```python + // print("hello") + // + // ``` + .missing_newline_hint(false) + .header( + &format!("{}:cell {}", &fs::relativize_path(path), idx), + &format!("{}:cell {}", &fs::relativize_path(path), idx), + ) + .to_writer(&mut stdout)?; + } + stdout.write_all(b"\n")?; + stdout.flush()?; + } + } } flags::FixMode::Generate => {} }